{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b4f70fcb-4ddc-4373-b01d-6199057beb48",
   "metadata": {},
   "source": [
    "Performs few-shot learning by prompting a GPT-2 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3cc3e552-63e4-48f6-92c5-5161fefc5373",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce501b5a-f260-4abc-89c0-d3a44c820c7f",
   "metadata": {},
   "source": [
    "# Pre-Processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "314ed453-2695-412a-9cc4-e73aed6f1972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maximum sequence length (in tokens) used when truncating reviews for the model.\n",
    "max_length = 400"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3e1ba918-c84f-4294-8edd-e81b7c617e0f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>review</th>\n",
       "      <th>sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>One of the other reviewers has mentioned that ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>A wonderful little production. &lt;br /&gt;&lt;br /&gt;The...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>I thought this was a wonderful way to spend ti...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Basically there's a family where a little boy ...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Petter Mattei's \"Love in the Time of Money\" is...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              review  sentiment\n",
       "0  One of the other reviewers has mentioned that ...          1\n",
       "1  A wonderful little production. <br /><br />The...          1\n",
       "2  I thought this was a wonderful way to spend ti...          1\n",
       "3  Basically there's a family where a little boy ...          0\n",
       "4  Petter Mattei's \"Love in the Time of Money\" is...          1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the IMDB reviews dataset and binarize the labels: 'positive' -> 1, anything else -> 0.\n",
    "df = pd.read_csv('IMDB Dataset.csv')\n",
    "df['sentiment'] = (df['sentiment'] == 'positive').astype(int)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "47bee21b-f3fa-4a01-8907-b8e0c85aee47",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0        307\n",
       "1        162\n",
       "2        166\n",
       "3        138\n",
       "4        230\n",
       "        ... \n",
       "49995    194\n",
       "49996    112\n",
       "49997    230\n",
       "49998    212\n",
       "49999    129\n",
       "Name: review, Length: 50000, dtype: int64"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Whitespace-split word count of each review, to gauge typical review length.\n",
    "df['review'].apply(lambda x: len(x.split()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a37d8308-ead5-48ee-8046-7174c401555f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average word count: 231.15694\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZIAAAEGCAYAAABPdROvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbOElEQVR4nO3df5QV5Z3n8fdHiIiJKEjrIA3TaJhkkGRG6XUxzklMiEc2uMI4cbadJJIMM2RcNjHj7EkgzsbNyeGMbqKObkYyrBghyYqEcZTEOOpgNDsJ/mhjFFAJbSDagQiJWWXcFQW/+0c9vZbN7dvVXX3v5XZ/Xufcc6u+VU/V81x+fM9TT9VTigjMzMwG64hGV8DMzJqbE4mZmZXiRGJmZqU4kZiZWSlOJGZmVsroRleg3iZOnBhtbW2NroaZWVN59NFHfxURLZW2jbhE0tbWRmdnZ6OrYWbWVCT9vK9tvrRlZmalOJGYmVkpTiRmZlaKE4mZmZXiRGJmZqU4kZiZWSlOJGZmVooTiZmZlVKzRCLpJkl7JG2psO0/SwpJE3OxZZK6JG2TdG4uPkvS5rTteklK8TGSbk3xhyS11aotZmbWt1o+2X4z8FVgTT4oaQpwDvBsLjYD6ABOBU4C/lnS70TEQWAFsBh4EPgeMBe4C1gE/CYi3i6pA7gK+A81bE9DtS29c9Bld145bwhrYmb2ZjXrkUTED4AXKmy6FvgskH8143xgbUTsj4gdQBdwhqRJwLiI2BTZqxzXAAtyZVan5fXAnJ7eipmZ1U9dx0gknQ/8IiIe77VpMvBcbr07xSan5d7xN5WJiAPAi8DxfZx3saROSZ179+4t3Q4zM3tD3RKJpKOBy4EvVNpcIRZV4tXKHBqMWBkR7RHR3tJScfJKMzMbpHr2SE4BpgGPS9oJtAI/lvRbZD2NKbl9W4FdKd5aIU6+jKTRwLFUvpRmZmY1VLdEEhGbI+KEiGiLiDayRHB6RPwS2AB0pDuxpgHTgYcjYjewT9LsNP5xMXBHOuQGYGFa/jBwXxpHMTOzOqrl7b+3AJuAd0jqlrSor30jYiuwDngS+CdgSbpjC+AS4EayAfhnyO7YAlgFHC+pC7gMWFqThpiZWVU1u/03Ii7qZ3tbr/XlwPIK+3UCMyvEXwEuLFdLMzMry0+2m5lZKU4kZmZWihOJmZmV4kRiZmalOJGYmVkpTiRmZlaKE4mZmZXiRGJmZqU4kZiZWSlOJGZmVooTiZmZleJEYmZmpTiRmJlZKU4kZmZWihOJmZmV4kRiZmalOJGYmVkpTiRmZlaKE4mZmZXiRGJmZqXULJFIuknSHklbcrEvS3pa0hOS/lHScbltyyR1Sdom6dxcfJakzWnb9ZKU4mMk3ZriD0lqq1VbzMysb7XskdwMzO0VuxeYGRHvBn4KLAOQNAPoAE5NZW6QNCqVWQEsBqanT88xFwG/iYi3A9cCV9WsJWZm1qeaJZKI+AHwQq/YPRFxIK0+CLSm5fnA2ojYHxE7gC7gDEmTgHERsSkiAlgDLMiVWZ2W1wNzenorZmZWP40cI/lT4K60PBl4LretO8Ump+Xe8TeVScnpReD4SieStFhSp6TOvXv3DlkDzMysQYlE0uXAAeBbPaEKu0WVeLUyhwYjVkZEe0S0t7S0DLS6ZmZWRd0TiaSFwHnAR9LlKsh6GlNyu7UCu1K8tUL8TWUkjQaOpdelNDMzq726JhJJc4HPAedHxP/JbdoAdKQ7saaRDao/HBG7gX2SZqfxj4uBO3JlFqblDwP35RKTmZnVyehaHVjSLcDZwERJ3cAVZHdpjQHuTePiD0bEX0TEVknrgCfJLnktiYiD6VCXkN0BNpZsTKVnXGUV8A1JXWQ9kY5atcXMzPpWs0QSERdVCK+qsv9yYHmFeCcws0L8FeDCMnU0M7Py/GS7mZmV4kRiZmalOJGYmVkpTiRmZlaKE4mZmZXiRGJmZqU4kZiZWSlOJGZm
VooTiZmZleJEYmZmpTiRmJlZKU4kZmZWihOJmZmV4kRiZmalOJGYmVkpTiRmZlaKE4mZmZXiRGJmZqU4kZiZWSlOJGZmVkrNEomkmyTtkbQlF5sg6V5J29P3+Ny2ZZK6JG2TdG4uPkvS5rTteklK8TGSbk3xhyS11aotZmbWt1r2SG4G5vaKLQU2RsR0YGNaR9IMoAM4NZW5QdKoVGYFsBiYnj49x1wE/CYi3g5cC1xVs5aYmVmfBpRIJI2X9O4i+0bED4AXeoXnA6vT8mpgQS6+NiL2R8QOoAs4Q9IkYFxEbIqIANb0KtNzrPXAnJ7eipmZ1U+/iUTS/ZLGSZoAPA58XdI1gzzfiRGxGyB9n5Dik4Hncvt1p9jktNw7/qYyEXEAeBE4fpD1MjOzQSrSIzk2Il4CLgC+HhGzgA8OcT0q9SSiSrxamUMPLi2W1Cmpc+/evYOsopmZVVIkkYxOl5j+GPhuyfM9n45F+t6T4t3AlNx+rcCuFG+tEH9TGUmjgWM59FIaABGxMiLaI6K9paWlZBPMzCyvSCL5InA30BURj0g6Gdg+yPNtABam5YXAHbl4R7oTaxrZoPrD6fLXPkmz0/jHxb3K9Bzrw8B9aRzFzMzqaHSBfXZHxP8fYI+InxUZI5F0C3A2MFFSN3AFcCWwTtIi4FngwnTMrZLWAU8CB4AlEXEwHeoSsjvAxgJ3pQ/AKuAbkrrIeiIdBdpiZmZDrEgi+e/A6QVibxIRF/WxaU4f+y8HlleIdwIzK8RfISUiMzNrnD4TiaQzgfcALZIuy20aB4yqXMrMzEaaaj2SI4G3pX2OycVfIhuTMDMz6zuRRMQDwAOSbo2Ip/PbJE2sec3MzKwpFLlra52k2T0rkv4I+FHtqmRmZs2kyGD7R4CbJN0PnET29PgHalkpMzNrHv0mkojYLGk58A1gH/DeiOjup5gdRtqW3jnosjuvnDeENTGz4ajfRCJpFXAK8G7gd4DvSPpqRPxdrStnZmaHvyJjJFuA90fEjoi4G5hNP8+QmJnZyNFvIomIa4GpknomanwV+EwtK2VmZs2jyDTyf072vo+/T6FW4PYa1snMzJpIkUtbS4CzyB5EJCK288Z7RMzMbIQrkkj2R8SrPStpynbPsmtmZkCxRPKApM8DYyWdA3wb+E5tq2VmZs2iSCJZCuwFNgOfBL4XEZfXtFZmZtY0ijzZ/qmIuA74Hz0BSZemmJmZjXBFeiQLK8Q+PsT1MDOzJlXtfSQXAX8CTJO0IbfpGODXta6YmZk1h2qXtn4E7AYmAlfn4vuAJ2pZKTMzax7V3kfyc+DnwJn1q46ZmTWbImMkZmZmfXIiMTOzUvpMJJI2pu+rhvqkkv5S0lZJWyTdIukoSRMk3Stpe/oen9t/maQuSdsknZuLz5K0OW27XpKGuq5mZlZdtR7JJEnvA86XdJqk0/OfwZ5Q0mTg00B7RMwERgEdZA8+boyI6cDGtI6kGWn7qcBc4AZJo9LhVgCLgenpM3ew9TIzs8GpdtfWF8j+M28Frum1LSj3ut3RZFOuvAYcDewClgFnp+2rgfuBzwHzgbURsR/YIakLOEPSTmBcRGwCkLQGWADcVaJeZmY2QNXu2loPrJf0XyLiS0N1woj4haSvAM8C/xe4JyLukXRiROxO++yW1DPD8GTgwdwhulPstbTcO34ISYvJei5MnTp1qJpiZmYUe7HVlySdL+kr6XNemROmsY/5wDTgJOCtkj5arUilalWJHxqMWBkR7RHR3tLSMtAqm5lZFUVebPU3wKXAk+lzaYoN1geBHRGxNyJeA24D3gM8L2lSOuckYE/avxuYkivfSnYprDst946bmVkdFbn9dx5wTkTcFBE3kQ1ozytxzmeB2ZKOTndZzQGeAjbwxrxeC4E70vIGoEPSGEnTyAbVH06XwfZJmp2Oc3GujJmZ1UmR2X8BjgNeSMvHljlhRDwkaT3wY+AA8BiwEngbsE7SIrJkc2Haf6ukdWS9oQPA
kog4mA53CXAzMJZskN0D7WZmdVYkkfwN8Jik75ONS7yX7A6rQYuIK4AreoX3k/VOKu2/HFheId4JzCxTFzMzK6ffRBIRt0i6H/g3ZInkcxHxy1pXzMzMmkOhS1tpPGJDvzuamdmI47m2zMysFCcSMzMrpWoikXSEpC31qoyZmTWfqokkIl4HHpfkeUXMzKyiIoPtk4Ctkh4GXu4JRsT5NauVmZk1jSKJ5Is1r4WZmTWtIs+RPCDpt4HpEfHPko4me4eImZlZoUkb/xxYD/x9Ck0Gbq9hnczMrIkUuf13CXAW8BJARGwHTqhawszMRowiiWR/RLzasyJpNH2898PMzEaeIonkAUmfJ3s17jnAt4Hv1LZaZmbWLIokkqXAXmAz8Enge8Bf17JSZmbWPIrctfW6pNXAQ2SXtLZFhC9tmZkZUCCRSJoHfA14hmwa+WmSPhkRfomUmZkVeiDxauD9EdEFIOkU4E78NkIzM6PYGMmeniSS/AzYU6P6mJlZk+mzRyLpgrS4VdL3gHVkYyQXAo/UoW5mZtYEql3a+ve55eeB96XlvcD4mtXIzMyaSp+JJCI+UauTSjoOuBGYSdbL+VNgG3Ar0AbsBP44In6T9l8GLAIOAp+OiLtTfBZwMzCW7LbkS31HmZlZfRWZa2uapGsk3SZpQ8+n5HmvA/4pIt4J/B7wFNnzKhsjYjqwMa0jaQbQAZwKzAVukNQzaeQKYDEwPX3mlqyXmZkNUJG7tm4HVpE9zf562RNKGge8F/g4QJp+5VVJ84Gz026rgfuBzwHzgbURsR/YIakLOEPSTmBcRGxKx10DLMB3k5mZ1VWRRPJKRFw/hOc8mWyc5euSfg94FLgUODEidgNExG5JPRNDTgYezJXvTrHX0nLv+CEkLSbruTB1ql/2aGY2lIrc/nudpCsknSnp9J5PiXOOBk4HVkTEaWRvXVxaZX9ViEWV+KHBiJUR0R4R7S0tLQOtr5mZVVGkR/Iu4GPAB3jj0lak9cHoBroj4qG0vp4skTwvaVLqjUzijWdVuoEpufKtwK4Ub60QNzOzOirSI/lD4OSIeF9EvD99BptEiIhfAs9JekcKzQGeBDYAC1NsIXBHWt4AdEgaI2ka2aD6w+ky2D5JsyUJuDhXxszM6qRIj+Rx4DiG9mn2TwHfknQk2ZPynyBLauskLQKeJXvwkYjYKmkdWbI5ACyJiIPpOJfwxu2/d+GBdjOzuiuSSE4Enpb0CLC/JxgR5w/2pBHxE6C9wqY5fey/HFheId5J9iyKmZk1SJFEckXNa2FmZk2ryPtIHqhHRczMrDkVeR/JPt64rfZI4C3AyxExrpYVMzOz5lCkR3JMfl3SAuCMWlXIDi9tS+8cdNmdV84bwpqY2eGqyO2/bxIRtzP4Z0jMzGyYKXJp64Lc6hFkd1t5hl0zMwOK3bWVfy/JAbIp3ufXpDZmZtZ0ioyR1Oy9JGZm1vyqvWr3C1XKRUR8qQb1MTOzJlOtR/Jyhdhbyd5UeDzgRDJAZe6AMjM7XFV71e7VPcuSjiF7Z8gngLXA1X2VMzOzkaXqGImkCcBlwEfI3lp4es971M3MzKD6GMmXgQuAlcC7IuJf61YrMzNrGtUeSPwr4CTgr4Fdkl5Kn32SXqpP9czM7HBXbYxkwE+9m5nZyONkYWZmpTiRmJlZKU4kZmZWihOJmZmV4kRiZmalNCyRSBol6TFJ303rEyTdK2l7+h6f23eZpC5J2ySdm4vPkrQ5bbtekhrRFjOzkayRPZJLgady60uBjRExHdiY1pE0A+gATgXmAjdIGpXKrAAWA9PTZ259qm5mZj0akkgktQLzgBtz4flk07CQvhfk4msjYn9E7AC6gDMkTQLGRcSmiAhgTa6MmZnVSaN6JH8LfBZ4PRc7MSJ2A6TvE1J8MvBcbr/uFJuclnvHDyFpsaROSZ179+4dkgaYmVmm7olE0nnAnoh4tGiRCrGoEj80GLEyItojor2lpaXgac3MrIgir9od
amcB50v6EHAUME7SN4HnJU2KiN3pstWetH83MCVXvhXYleKtFeJmZlZHde+RRMSyiGiNiDayQfT7IuKjwAZgYdptIXBHWt4AdEgaI2ka2aD6w+ny1z5Js9PdWhfnypiZWZ00okfSlyuBdZIWAc8CFwJExFZJ64AngQPAkog4mMpcAtwMjAXuSh8zM6ujhiaSiLgfuD8t/xqY08d+y4HlFeKdwMza1dDMzPpzOPVIbJgp+476nVfOG6KamFkteYoUMzMrxYnEzMxKcSIxM7NSnEjMzKwUJxIzMyvFicTMzEpxIjEzs1KcSMzMrBQnEjMzK8WJxMzMSnEiMTOzUpxIzMysFCcSMzMrxbP/2mGrzOzBnjnYrH7cIzEzs1KcSMzMrBQnEjMzK8WJxMzMSnEiMTOzUuqeSCRNkfR9SU9J2irp0hSfIOleSdvT9/hcmWWSuiRtk3RuLj5L0ua07XpJqnd7zMxGukb0SA4AfxURvwvMBpZImgEsBTZGxHRgY1onbesATgXmAjdIGpWOtQJYDExPn7n1bIiZmTUgkUTE7oj4cVreBzwFTAbmA6vTbquBBWl5PrA2IvZHxA6gCzhD0iRgXERsiogA1uTKmJlZnTR0jERSG3Aa8BBwYkTshizZACek3SYDz+WKdafY5LTcO25mZnXUsCfbJb0N+AfgMxHxUpXhjUobokq80rkWk10CY+rUqQOvrDUdPxVvVj8N6ZFIegtZEvlWRNyWws+ny1Wk7z0p3g1MyRVvBXaleGuF+CEiYmVEtEdEe0tLy9A1xMzMGnLXloBVwFMRcU1u0wZgYVpeCNyRi3dIGiNpGtmg+sPp8tc+SbPTMS/OlTEzszppxKWts4CPAZsl/STFPg9cCayTtAh4FrgQICK2SloHPEl2x9eSiDiYyl0C3AyMBe5KHzMzq6O6J5KI+Bcqj28AzOmjzHJgeYV4JzBz6GpnZmYD5SfbzcysFL+PZADK3AlkZjZcuUdiZmaluEdi1oufQTEbGPdIzMysFCcSMzMrxYnEzMxKcSIxM7NSPNhuNoQ8UG8jkXskZmZWihOJmZmV4kRiZmaleIzE7DBRdgoej7FYo7hHYmZmpbhHYjZM+I4xaxT3SMzMrBQnEjMzK8WXtszMl8WsFCcSMyvFSah+Dtc7+5xIzKxhDtf/GG1gnEjMrGm5N3R4aPpEImkucB0wCrgxIq5scJXMrAmU7Q0N1nBMYE2dSCSNAv4OOAfoBh6RtCEinmxszczMKmtUAqulZr/99wygKyJ+FhGvAmuB+Q2uk5nZiNLUPRJgMvBcbr0b+Le9d5K0GFicVv9V0rZBnGsi8KtBlGtmbvPIMBLbDCOw3bqqVJt/u68NzZ5IVCEWhwQiVgIrS51I6oyI9jLHaDZu88gwEtsMI7PdtWpzs1/a6gam5NZbgV0NqouZ2YjU7InkEWC6pGmSjgQ6gA0NrpOZ2YjS1Je2IuKApP8E3E12++9NEbG1RqcrdWmsSbnNI8NIbDOMzHbXpM2KOGRIwczMrLBmv7RlZmYN5kRiZmalOJEUIGmupG2SuiQtbXR9hoKkKZK+L+kpSVslXZriEyTdK2l7+h6fK7Ms/QbbJJ3buNqXI2mUpMckfTetj4Q2HydpvaSn05/5mcO93ZL+Mv3d3iLpFklHDbc2S7pJ0h5JW3KxAbdR0ixJm9O26yVVerSibxHhT5UP2SD+M8DJwJHA48CMRtdrCNo1CTg9LR8D/BSYAfw3YGmKLwWuSsszUtvHANPSbzKq0e0YZNsvA/4n8N20PhLavBr4s7R8JHDccG432cPKO4CxaX0d8PHh1mbgvcDpwJZcbMBtBB4GziR7Nu8u4N8NpB7ukfRvWE7DEhG7I+LHaXkf8BTZP775ZP/pkL4XpOX5wNqI2B8RO4Aust+mqUhqBeYBN+bCw73N48j+w1kFEBGvRsT/Zpi3m+yu1LGSRgNHkz1jNqzaHBE/AF7oFR5QGyVNAsZFxKbIssqaXJlCnEj6
V2kalskNqktNSGoDTgMeAk6MiN2QJRvghLTbcPkd/hb4LPB6Ljbc23wysBf4erqkd6OktzKM2x0RvwC+AjwL7AZejIh7GMZtzhloGyen5d7xwpxI+ldoGpZmJeltwD8An4mIl6rtWiHWVL+DpPOAPRHxaNEiFWJN1eZkNNnljxURcRrwMtklj740fbvTuMB8sks4JwFvlfTRakUqxJqqzQX01cbSbXci6d+wnYZF0lvIksi3IuK2FH4+dXVJ33tSfDj8DmcB50vaSXaJ8gOSvsnwbjNk7eiOiIfS+nqyxDKc2/1BYEdE7I2I14DbgPcwvNvcY6Bt7E7LveOFOZH0b1hOw5LuylgFPBUR1+Q2bQAWpuWFwB25eIekMZKmAdPJBuiaRkQsi4jWiGgj+3O8LyI+yjBuM0BE/BJ4TtI7UmgO8CTDu93PArMlHZ3+rs8hGwcczm3uMaA2pstf+yTNTr/VxbkyxTT6roNm+AAfIrur6Rng8kbXZ4ja9Adk3dcngJ+kz4eA44GNwPb0PSFX5vL0G2xjgHd1HG4f4GzeuGtr2LcZ+H2gM/153w6MH+7tBr4IPA1sAb5BdrfSsGozcAvZGNBrZD2LRYNpI9CefqdngK+SZj0p+vEUKWZmVoovbZmZWSlOJGZmVooTiZmZleJEYmZmpTiRmJlZKU4kZn2QdK2kz+TW75Z0Y279akmXDfLYZ/fMPlxPaRbg/1jv89rw5kRi1rcfkT0NjaQjgInAqbnt7wF+WORAkkYNee0G5zjAicSGlBOJWd9+SEokZAlkC9kTwOMljQF+F3hM0pw0GeLm9H6IMQCSdkr6gqR/AS5U9l6bp9P6BZVOqOxdKV9Jx3pC0qdSvNo5Jqbldkn3p+X/mva7X9LPJH06neJK4BRJP5H05Rr8ZjYCjW50BcwOVxGxS9IBSVPJEsomsllRzwReJHtK/AjgZmBORPxU0hrgErJZhgFeiYg/kHQU2ZPGHyCbvvvWPk67mGyiwdMi4kB6SdFR/ZyjL+8E3k/2vpltklaQTdY4MyJ+fyC/hVk17pGYVdfTK+lJJJty6z8C3kE2OeBP0/6ryd790aMnYbwz7bc9sukkvtnH+T4IfC0iDgBExAsFztGXOyN798SvyCbuO7FAGbMBcyIxq65nnORdZJe2HiTrkfSMj/T3StKXc8tF5iNShf2qneMAb/w7PqrXtv255YP4CoTViBOJWXU/BM4DXoiIg6mHcBxZMtlENilgm6S3p/0/BjxQ4ThPA9MknZLWL+rjfPcAf5He6oekCf2cYycwKy3/UYH27CO71GU2ZJxIzKrbTHa31oO9Yi9GxK8i4hXgE8C3JW0me/Pi13ofJO23GLgzDbb/vI/z3Ug2BfoTkh4H/qSfc3wRuE7S/yLrdVQVEb8Gfihpiwfbbah49l8zMyvFPRIzMyvFicTMzEpxIjEzs1KcSMzMrBQnEjMzK8WJxMzMSnEiMTOzUv4fDKUZqUn2aRUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compute word counts once and reuse them for both the histogram and the average,\n",
    "# instead of re-running the same .apply() twice.\n",
    "word_counts = df['review'].apply(lambda x: len(x.split()))\n",
    "# Clip at 1000 words so a handful of very long reviews don't stretch the x-axis;\n",
    "# clip(upper=1000) is equivalent to min(len, 1000) per element.\n",
    "plt.hist(word_counts.clip(upper=1000), bins=20)\n",
    "plt.ylabel(\"Number of texts\")\n",
    "plt.xlabel(\"Word count\")\n",
    "# The average is computed on the UNclipped counts, matching the original output.\n",
    "print(f\"average word count: {np.mean(word_counts)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7f23d60b-256c-463b-b324-f1615b3ab0ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Do a train-val-test split 80%-10%-10%\n",
    "# Fix random_state so the split (and all downstream results) is reproducible\n",
    "# under Restart Kernel -> Run All.\n",
    "X_train, X_test, y_train, y_test = train_test_split(np.array(df['review']), np.array(df['sentiment']), test_size=0.2, random_state=42)\n",
    "X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61a0ccec-df85-4bb4-84d3-71611c0ccb56",
   "metadata": {},
   "source": [
    "# GPT2 Fine Tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "944e90fc-2a50-41fd-ab4a-3bdc3210c8b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load GPT-2 classes and training utilities from Hugging Face.\n",
    "# Requires: pip install transformers\n",
    "# If you run this in a Jupyter notebook and get an error mentioning IPython widgets, see:\n",
    "# https://stackoverflow.com/questions/53247985/tqdm-4-28-1-in-jupyter-notebook-intprogress-not-found-please-update-jupyter-an\n",
    "from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "03a2618c-5234-4df7-adf0-0aeb90ebcd7a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94ba93b43e5a48e38ef48745ab42f735",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8649b3bff7aa4cc486c342db645b3b46",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cf6ff6ed126f4d21bef0cac81f581920",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd5ee3fb3be6463e8bccf9651471a27e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/548M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-14-457ad9080eb6>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      2\u001b[0m \u001b[1;31m# Documentation for GPT: https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[0mgpt_tokenizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mGPT2Tokenizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'gpt2'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mgpt_model\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mGPT2LMHeadModel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'gpt2'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\site-packages\\transformers\\modeling_utils.py\u001b[0m in \u001b[0;36mfrom_pretrained\u001b[1;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[0;32m    644\u001b[0m             \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    645\u001b[0m                 \u001b[1;31m# Load from URL or cache if already cached\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 646\u001b[1;33m                 resolved_archive_file = cached_path(\n\u001b[0m\u001b[0;32m    647\u001b[0m                     \u001b[0marchive_file\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    648\u001b[0m                     \u001b[0mcache_dir\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcache_dir\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\site-packages\\transformers\\file_utils.py\u001b[0m in \u001b[0;36mcached_path\u001b[1;34m(url_or_filename, cache_dir, force_download, proxies, resume_download, user_agent, extract_compressed_file, force_extract, local_files_only)\u001b[0m\n\u001b[0;32m    562\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mis_remote_url\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0murl_or_filename\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    563\u001b[0m         \u001b[1;31m# URL, so get it from the cache (downloading if necessary)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 564\u001b[1;33m         output_path = get_from_cache(\n\u001b[0m\u001b[0;32m    565\u001b[0m             \u001b[0murl_or_filename\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    566\u001b[0m             \u001b[0mcache_dir\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcache_dir\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\site-packages\\transformers\\file_utils.py\u001b[0m in \u001b[0;36mget_from_cache\u001b[1;34m(url, cache_dir, force_download, proxies, etag_timeout, resume_download, user_agent, local_files_only)\u001b[0m\n\u001b[0;32m    748\u001b[0m             \u001b[0mlogger\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"%s not found in cache or force_download set to True, downloading to %s\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0murl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtemp_file\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    749\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 750\u001b[1;33m             \u001b[0mhttp_get\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0murl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtemp_file\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mproxies\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mproxies\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresume_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mresume_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0muser_agent\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0muser_agent\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    751\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    752\u001b[0m         \u001b[0mlogger\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"storing %s in cache at %s\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0murl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcache_path\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\site-packages\\transformers\\file_utils.py\u001b[0m in \u001b[0;36mhttp_get\u001b[1;34m(url, temp_file, proxies, resume_size, user_agent)\u001b[0m\n\u001b[0;32m    641\u001b[0m         \u001b[0mdisable\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbool\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlogger\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgetEffectiveLevel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mlogging\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mNOTSET\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    642\u001b[0m     )\n\u001b[1;32m--> 643\u001b[1;33m     \u001b[1;32mfor\u001b[0m \u001b[0mchunk\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mresponse\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0miter_content\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mchunk_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1024\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    644\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mchunk\u001b[0m\u001b[1;33m:\u001b[0m  \u001b[1;31m# filter out keep-alive new chunks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    645\u001b[0m             \u001b[0mprogress\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mchunk\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\site-packages\\requests\\models.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m()\u001b[0m\n\u001b[0;32m    751\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mraw\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'stream'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    752\u001b[0m                 \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 753\u001b[1;33m                     \u001b[1;32mfor\u001b[0m \u001b[0mchunk\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mraw\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstream\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mchunk_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecode_content\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    754\u001b[0m                         \u001b[1;32myield\u001b[0m \u001b[0mchunk\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    755\u001b[0m                 \u001b[1;32mexcept\u001b[0m \u001b[0mProtocolError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\site-packages\\urllib3\\response.py\u001b[0m in \u001b[0;36mstream\u001b[1;34m(self, amt, decode_content)\u001b[0m\n\u001b[0;32m    574\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    575\u001b[0m             \u001b[1;32mwhile\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mis_fp_closed\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_fp\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 576\u001b[1;33m                 \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mamt\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mamt\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecode_content\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdecode_content\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    577\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    578\u001b[0m                 \u001b[1;32mif\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\site-packages\\urllib3\\response.py\u001b[0m in \u001b[0;36mread\u001b[1;34m(self, amt, decode_content, cache_content)\u001b[0m\n\u001b[0;32m    517\u001b[0m             \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    518\u001b[0m                 \u001b[0mcache_content\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 519\u001b[1;33m                 \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_fp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mamt\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mfp_closed\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;34mb\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    520\u001b[0m                 if (\n\u001b[0;32m    521\u001b[0m                     \u001b[0mamt\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[1;36m0\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\http\\client.py\u001b[0m in \u001b[0;36mread\u001b[1;34m(self, amt)\u001b[0m\n\u001b[0;32m    453\u001b[0m             \u001b[1;31m# Amount is given, implement using readinto\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    454\u001b[0m             \u001b[0mb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbytearray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mamt\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 455\u001b[1;33m             \u001b[0mn\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreadinto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    456\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mmemoryview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtobytes\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    457\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\http\\client.py\u001b[0m in \u001b[0;36mreadinto\u001b[1;34m(self, b)\u001b[0m\n\u001b[0;32m    497\u001b[0m         \u001b[1;31m# connection, and the user is reading more bytes than will be provided\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    498\u001b[0m         \u001b[1;31m# (for example, reading in 1k chunks)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 499\u001b[1;33m         \u001b[0mn\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreadinto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    500\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mn\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mb\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    501\u001b[0m             \u001b[1;31m# Ideally, we would raise IncompleteRead if the content-length\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\socket.py\u001b[0m in \u001b[0;36mreadinto\u001b[1;34m(self, b)\u001b[0m\n\u001b[0;32m    667\u001b[0m         \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    668\u001b[0m             \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 669\u001b[1;33m                 \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sock\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecv_into\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    670\u001b[0m             \u001b[1;32mexcept\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    671\u001b[0m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_timeout_occurred\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\ssl.py\u001b[0m in \u001b[0;36mrecv_into\u001b[1;34m(self, buffer, nbytes, flags)\u001b[0m\n\u001b[0;32m   1239\u001b[0m                   \u001b[1;34m\"non-zero flags not allowed in calls to recv_into() on %s\"\u001b[0m \u001b[1;33m%\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1240\u001b[0m                   self.__class__)\n\u001b[1;32m-> 1241\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnbytes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbuffer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1242\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1243\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecv_into\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnbytes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mflags\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\ssl.py\u001b[0m in \u001b[0;36mread\u001b[1;34m(self, len, buffer)\u001b[0m\n\u001b[0;32m   1097\u001b[0m         \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1098\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mbuffer\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1099\u001b[1;33m                 \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sslobj\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbuffer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1100\u001b[0m             \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1101\u001b[0m                 \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sslobj\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Download the pretrained GPT-2 tokenizer and LM-head model from the Hugging Face hub.\n",
    "# Documentation for GPT: https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel\n",
    "gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
    "gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2621d1bd-c2d1-400d-932e-715815280012",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset wrapper for causal-LM fine-tuning. Every text is tokenized up front by\n",
    "# batch_encode_plus and the encodings are held in memory for the whole run.\n",
    "class GPT2Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, txt_list, tokenizer, max_length=max_length):\n",
    "        # Tokenize all texts at once, truncating/padding to a fixed length and\n",
    "        # producing attention masks so every sample has the same sequence length.\n",
    "        encoded = tokenizer.batch_encode_plus(txt_list, truncation=True, max_length=max_length, padding=\"max_length\")\n",
    "        self.input_ids = encoded['input_ids']\n",
    "        self.attn_masks = encoded['attention_mask']\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.input_ids)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Labels mirror input_ids: standard next-token-prediction fine-tuning.\n",
    "        return {\"input_ids\": torch.tensor(self.input_ids[idx]),\n",
    "                \"attention_mask\": torch.tensor(self.attn_masks[idx]),\n",
    "                \"labels\": torch.tensor(self.input_ids[idx])}
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "12b29955-cfdc-464a-9253-0eee67524827",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPT-2 ships with no pad token; reuse the EOS token so padding=\"max_length\" works.\n",
    "gpt_tokenizer.pad_token = gpt_tokenizer.eos_token\n",
    "train_dataset = GPT2Dataset(X_train, gpt_tokenizer)\n",
    "val_dataset = GPT2Dataset(X_val, gpt_tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78855ca2-a24b-4452-b1ff-d9891f3c0e4c",
   "metadata": {},
   "source": [
    "Fine-tune the gpt_model using the Hugging Face out-of-the-box trainer: https://huggingface.co/transformers/custom_datasets.html#fine-tuning-with-trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "134989b8-2794-4412-9ea0-2a77d48b293d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse a previously fine-tuned checkpoint when one exists; otherwise fine-tune now.\n",
    "if os.path.exists(\"gpt_model.model\"):\n",
    "    # map_location ensures a checkpoint saved on GPU still loads on a CPU-only machine\n",
    "    gpt_model = torch.load(\"gpt_model.model\", map_location=device)\n",
    "else:\n",
    "    training_args = TrainingArguments(\n",
    "        output_dir='gpt_finetuning',     # output directory\n",
    "        num_train_epochs=1,              # total number of training epochs (1 is enough to get very low perplexity and perplexity increases at 2)\n",
    "        per_device_train_batch_size=4,   # batch size per device during training\n",
    "        per_device_eval_batch_size=4,    # batch size for evaluation\n",
    "        warmup_steps=500,                # number of warmup steps for learning rate scheduler\n",
    "        weight_decay=0.001,              # strength of weight decay\n",
    "        logging_dir='gpt_finetuning_logs',  # directory for storing logs\n",
    "        logging_steps=100,\n",
    "\n",
    "        # Increase the betas as the batch size is quite small so updates are very stochastic\n",
    "        adam_beta1=0.95,\n",
    "        adam_beta2=0.9995,\n",
    "        learning_rate=1e-5\n",
    "    )\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=gpt_model,                 # the instantiated 🤗 Transformers model to be trained\n",
    "        args=training_args,              # training arguments, defined above\n",
    "        train_dataset=train_dataset,     # training dataset\n",
    "        eval_dataset=val_dataset         # evaluation dataset\n",
    "    )\n",
    "\n",
    "    trainer.train()
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ab400612-98a3-4f41-b0a8-803963bff92e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the fine-tuned model. NOTE(review): torch.save on a whole nn.Module pickles\n",
    "# the full object; saving state_dict() would be more portable, but the load cell above\n",
    "# expects a pickled module — confirm before changing the format.\n",
    "torch.save(gpt_model, \"gpt_model.model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8f6c92a-403f-43e8-8312-52edd172e04c",
   "metadata": {},
   "source": [
    "# Testing Prompt Classification\n",
    "Perform classification by appending a prompt such as \"All in all the movie was\" to the end of a review and classifying based on the relative probability of \"good\" vs. \"bad\" as the next token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4bbc1e49-7a01-48ed-9896-3bc2dfeb761c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a new model class, takes a list of strings as input and outputs the binary classification probability p(good | review) / (p(good | review) + p(bad | review))\n",
    "class gpt_prompt(torch.nn.Module):\n",
    "    def __init__(self, gpt_model, tokenizer, prompt=\"in summary the movie was\", pos_token=\"good\", neg_token=\"bad\"):\n",
    "        super().__init__()\n",
    "        self.gpt = gpt_model.to(device)\n",
    "        self.tokenizer = tokenizer\n",
    "        self.prompt = prompt\n",
    "        self.prompt_length = len(tokenizer.encode(prompt))\n",
    "        # NOTE(review): encode(pos_token)[0] tokenizes \"good\"/\"bad\" WITHOUT a leading\n",
    "        # space; after \"...was\" GPT-2's byte-level BPE typically emits the \" good\"/\" bad\"\n",
    "        # variants (see the top-token listing below) — confirm the intended token ids.\n",
    "        self.pos_token = tokenizer.encode(pos_token)[0]\n",
    "        self.neg_token = tokenizer.encode(neg_token)[0]\n",
    "        \n",
    "    def forward(self, sentences):\n",
    "        # Loop through and add prompts to the sentences, then perform prediction.\n",
    "        # prompt_end_idx tracks the position of the prompt's final token so predict()\n",
    "        # can read the next-token distribution at exactly that position.\n",
    "        preds = []\n",
    "        for string in sentences:\n",
    "            length = len(self.tokenizer.encode(string))\n",
    "            if length + self.prompt_length <= max_length:\n",
    "                # Sentence + prompt fits: append prompt, pad the remainder.\n",
    "                encodings_dict = self.tokenizer.encode_plus(string + \" \" + self.prompt, truncation=True, max_length=max_length, padding=\"max_length\")\n",
    "                prompt_end_idx = length + self.prompt_length - 1\n",
    "            else: # encode full sentence then replace the (truncated) end with the prompt\n",
    "                encodings_dict = self.tokenizer.encode_plus(string, truncation=True, max_length=max_length, padding=\"max_length\")\n",
    "                encodings_dict[\"input_ids\"][-self.prompt_length:] = self.tokenizer.encode(self.prompt)\n",
    "                prompt_end_idx = max_length - 1\n",
    "            preds.append(self.predict(encodings_dict, prompt_end_idx))\n",
    "            # Tests\n",
    "            # print(self.tokenizer.decode(encodings_dict[\"input_ids\"]))\n",
    "            #print(prompt_end_idx, self.tokenizer.decode(encodings_dict[\"input_ids\"][prompt_end_idx]))\n",
    "            assert self.tokenizer.decode(encodings_dict[\"input_ids\"][prompt_end_idx]) != \"<|endoftext|>\", \"Prompt_end_idx points to a padding token not the prompt\"\n",
    "        return preds\n",
    "    \n",
    "    # given an encoded sentence + prompt and the end of the prompt return binary classification probability\n",
    "    # Returns a 2-tuple: (normalized P(pos) / (P(pos) + P(neg)), full softmaxed vocab distribution).\n",
    "    def predict(self, encodings_dict, prompt_end_idx):\n",
    "        input_ids = torch.tensor([encodings_dict[\"input_ids\"]]).long().to(device)\n",
    "        attn_mask = torch.tensor([encodings_dict[\"attention_mask\"]]).long().to(device)\n",
    "        # Logits at the prompt's last token = distribution over the NEXT token.\n",
    "        word_preds = self.gpt(input_ids, attention_mask=attn_mask)[\"logits\"][0, prompt_end_idx, :]\n",
    "        word_preds = torch.nn.functional.softmax(word_preds, dim=0)\n",
    "        return word_preds[self.pos_token]/(word_preds[self.pos_token] + word_preds[self.neg_token]), word_preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c31ed1db-2953-474e-a246-98d48c3b1ecd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build the prompt-based classifier and switch it to inference mode.\n",
    "prompted = gpt_prompt(gpt_model, gpt_tokenizer, prompt=\"All in all the movie was\", pos_token=\"good\", neg_token=\"bad\")\n",
    "prompted.eval();  # trailing semicolon suppresses the module repr in the cell output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4db324b1-eae1-482c-9e3a-cc2df12ae506",
   "metadata": {},
   "source": [
    "## An example\n",
    "Here we view the model's prediction for the following positive review"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d2ef3c22-0707-4b61-ad7e-981a4182a19d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A known-positive review used as a qualitative sanity check of the classifier below.\n",
    "review = \"\"\"The unlikely duo of Zero Mostel and Harry Belafonte team up to give us some interesting performances and subject matter in The Angel Levine. \n",
    "It's one interesting twist on the themes from It's A Wonderful Life.<br /><br />Zero is married to Ida Kaminsky and the two of them belong to a special class of elderly Jewish poor in New York.\n",
    "Mostel used to be a tailor and proud of his trade, but his back and arthritis have prevented him from working. Kaminsky is mostly bedridden. He's reduced to applying for welfare. \n",
    "In desperation like Jimmy Stewart, he cries out to God for some help.<br /><br />Now maybe if he had gotten someone like Henry Travers things might have worked out differently, \n",
    "but even Stewart had trouble accepting Travers. But Travers had one thing going for him, he was over 100 years off this mortal coil and all his ties to earthly things were gone.\n",
    "God sent Mostel something quite different, the recently deceased Harry Belafonte who should have at least been given some basic training for angels before being given an assignment.<br /><br />Belafonte\n",
    "hasn't accepted he's moved on from life, he's still got a lot of issues. He also has a wife, Gloria Foster, who doesn't know he's passed on, hit by a car right at the beginning of the film. \n",
    "You put his issues and Mostel's issues and you've got a good conflict, starting with the fact that Mostel can't believe in a black Jew named Levine.<br /><br />This was the farewell performance for \n",
    "Polish/Jewish actress Ida Kaminsky who got a nomination for Best Actress in The Shop on Main Street a few years back. The other prominent role here is that of Irish actor Milo O'Shea playing a nice \n",
    "Jewish doctor. Remembering O'Shea's brogue from The Verdict, I was really surprised to see and hear him carry off the part of the doctor.<br /><br />The Angel Levine raises some interesting and \n",
    "disturbing questions about faith and race in this society. It's brought to you by a stellar cast and of course created by acclaimed writer Bernard Malamud. Make sure to catch it when broadcast.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7e30b75c-b4d8-4d47-ad11-80939c6c7f38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "474"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token count of the example review — it exceeds max_length (400), so the classifier\n",
    "# will take the truncate-and-splice branch in forward().\n",
    "len(gpt_tokenizer.encode(review))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ff9d8d40-b9db-4f0a-bbd4-2094eb5badae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[' a', ' very', ' great', ' an', ' the', ' so', ' just', ' one', ' nice', ' really', ' pretty', ' like', ' interesting', ' good', ' quite', ' wonderful', ' about', ' not', ',', ' funny', ' also', ' kind', ' well', ' fun', ' his', ' brilliant', ' amazing', ' excellent', ' memorable', ' beautiful', ' hilarious', ' such', ' more', ' touching', ' rather', ' enjoyable', ' almost', ' that', ' all', ' entertaining', ' fascinating', ' impressive', ' another', ' probably', ' as', ' much', ' superb', ' truly', ' painful', ' lovely', ' perfect', ' too', ' remarkable', ' to', ' fantastic', ' in', ' this', ' extremely', ' amusing', ' going', ' full', ' worth', ' somewhat', ' what', ' especially', ' done', ' surreal', ' way', ' only', ' something', ' fairly', ' better', ' delightful', ' enough', ' awesome', ' made', ' her', ' simply', ' even', ' pure', ' moving', ' bad', ' hard', ' mostly', ' actually', ' some', ' supposed', ' fine', ' no', ' particularly', ' incredible', ' meant', ' exciting', ' incredibly', ' bitters', ' nothing', ' over', ' by', ' still', ' special']\n",
      "Positive likelihood: tensor(0.9133, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Inspect the next-token distribution at the prompt position for the example review.\n",
    "with torch.no_grad():\n",
    "    pos_prob, token_probs = prompted.forward([review])[0]\n",
    "    vocab_probs = token_probs.cpu().numpy()\n",
    "    top_100 = list(np.argsort(vocab_probs)[::-1][:100])\n",
    "    print([gpt_tokenizer.decode([tok]) for tok in top_100])\n",
    "    print(\"Positive likelihood:\", pos_prob)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1de2cc71-85e6-4161-b3c0-2fcd28b8cdec",
   "metadata": {},
   "source": [
    "We can see that the most likely words (near the start of the list) are positive, e.g. \"great\", \"excellent\", \"good\", whereas negative words such as \"mediocre\" and \"boring\" rank much lower."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "11c02887-73e5-4af1-b162-7c0deb9711fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release cached GPU memory accumulated during the forward passes above.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4692123-0c81-4cd0-a400-3b31260eb981",
   "metadata": {},
   "source": [
    "## Accuracy evaluation and threshold tuning\n",
    "Here we find the optimal prediction threshold. Our positive token \"good\" has a higher base probability than \"bad\", so with a threshold of 0.5 the model tends to predict all samples as positive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "022d6efd-756d-4fcb-b3d8-3c85f77d47d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(data):\n",
    "    \"\"\"Return the positive-class probability for each input string as a numpy array.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        outputs = prompted.forward(data)\n",
    "        # Keep only the scalar probability from each (probability, distribution) pair.\n",
    "        pos_probs = [pair[0] for pair in outputs]\n",
    "        pos_probs = torch.tensor(pos_probs).cpu().numpy()\n",
    "    return pos_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "99ae02d0-fc32-4cd2-b236-08b54f8eb83a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (1076 > 1024). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test set accuracy before tuning: 0.6324\n"
     ]
    }
   ],
   "source": [
    "# Baseline test-set accuracy with the naive 0.5 threshold (before tuning).\n",
    "prompted.eval()\n",
    "print(\"Test set accuracy before tuning: {}\".format(accuracy_score(y_test, predict(X_test) > 0.5)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "0eda797e-fc60-407b-807e-460035cce628",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep candidate thresholds on the validation set, recording accuracy at each.\n",
    "preds = predict(X_val)\n",
    "thresholds = np.arange(0, 1, 0.01)\n",
    "accuracies = [accuracy_score(y_val, preds > t) for t in thresholds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "0bbd0246-4260-4c29-bab6-25095d3f4382",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAacAAAEWCAYAAADCeVhIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8iklEQVR4nO3dd3gc1dXA4d9RL5Ysyb3IFVdww8L0XmIgYDqGBGxCCSGQhBAIED4gQBJCCiVAqA41NuCAY8BATO+4994tuckqtmxZ/Xx/zMgeLyorW6tZ7Z73efRop5/ZnZ2zc++dO6KqGGOMMeEkxu8AjDHGmECWnIwxxoQdS07GGGPCjiUnY4wxYceSkzHGmLBjyckYY0zYafHkJCIPiMh2EdniDp8vIhtFZJeIjGjpeDxxHVAcInKviLwSytjq2W4PN9bYlt72wRKRF0TkAff18SKy/ADX85SI/F/zRlfvtj4VkWtaYDt735sDWLbeGEWkl4ioiMQdXISh4/0uNXZ8H+z3TkQWi8hJB7q8Cb1mT04isk5E9rgHVu3f4+60HsAtwGBV7ewu8lfgRlVto6pzD2K7KiKHHETozRJHS1HVDW6s1X7HcjBU9QtVHdDYfCIyXkS+DFj2elW9v7lj8usHh9mnOY/vuhK+qh6qqp8e7LqbSzOcvyJOqH5FnaOqH9YxvgdQoKrbPON6AotDFEdThEscrYqIxKlqld9xhBN7T0yk8eOYbrFiPRE5DZgOdHWvpiaKyC4gFpgvIqvd+bqKyH9EJF9E1orILzzriBWRO0VktYiUiMhsEckWkc/dWea76760ju3HiMhdIrJeRLaJyEsi0lZEEuuKo47lDxWR6SJSKCJbReROz+QEd30lbnFBjme52z3xLhGR8z3TxovIlyLyVxEpcvf3TM/03iLyubvshyLyhKfYY79iGrdI534R+cqd/38i0t6zrivdfS8Qkf9zr3BPq2dfX3CLzKa76/pMRHp6pquI/FxEVgIr3XE/FJF5IlIsIl+LyFDP/CNEZI67rteAJM+0k0Qk1zOcLSJvup9/gYg8LiKDgKeAo93Pt9gT5wOeZa8VkVXuZzRVRLoGxHy9iKx0Y3xCRKSOfR8N3Alc6m5rvmdyz7reX89ncbWIbAA+dsf/RESWup/tB7XvoTgedo/DnSKyUEQO82wnU0TedbfznYj09cR3jIjMFJEd7v9j6vkMY93jaruIrAHOrms+z/yD3GOo2D2Gz/VMe8F9v+qMKWA974nIjQHj5ovIBe7rR8UpPt8pzvf3+HrWE3h893aPwxIRmQ60D5j/DRHZ4r4vn4vIoe7464AfAbe5n+fb7vi9x78454BHRGST+/eIiCS6004SkVwRucX9vDaLyFUNvI/jRWSNG+daEfmRZ1p9x0Mw56++IvKxON+J7SLyqohkeKZ/73vjmXatu93ac9Dh7vj9rtZk/+L22v3+rThVMP8SkUwRecfdRpH7urtn+SwR+Zf7HhaJyBR3/CIROcczX7y7Dw1Xn6hqs/4B64DT6pl2EpAbME6BQ9zXMcBs4G4gAegDrAF+4E6/FVgIDAAEGAa0C1xPPdv+CbDKXWcb4E3g5briqGPZNGAzTpFkkjt8pDvtXqAMOAsnwf0J+Naz7MVAV3ffLgV2A13caeOBSuBad9mfAZsAcad/g1PcmAAcB+wEXnGn9XJjjnOHPwVWA/2BZHf4QXfaYGCXu44Ed52VDXxOLwAlwAlAIvAo8GXAezUdyHK3NQLYBhzp7sc49zhIdLe3HrgZiAcucrf9QOAx4S47H3gYSHXf6+M879WXdcRZu55TgO3A4e52/wF8HhDzO0AGzhV8PjC6nv2/t/Z99oxr6P2t/SxecuNOBsbgHG+DcEoo7gK+duf/Ac5xnoFzHA9i3zHxAlAAjHKXexWY5E7LAoqAK9xpl7nD7TwxXuO+vh5YBmS7y32C53gJ2Ld4N9Y73c/rFPfzH9BYTHWs60rg
K8/wYKAYSHSHfwy0c9dzC7AFSAp83/n+8f0N8Hf3sz3Bje8Vz3Z+gvO9TAQeAebVdZzUdZ4C7gO+BToCHYCvgfs9x2eVO088zve8FMisY99Tcb6jte9bF+BQ93W9x0OQ569DgNPd/esAfA48EsT35mIgDzgC51g7BOhZ1zbZ//tUu99/dreZ7H5uFwIp7nv9BjDFs/y7wGtApvteneiOvw14zTPfGGBho7mksRma+ud+6LtwDsjav2sDT0R1fSg4J7cNAdPvAP7lvl4OjKlnu419uB8BN3iGB+CcJOMaWx7nJDC3gRPZhwFfxj0NxDGvdh9wTrirPNNS3Dg645xAq4AUz/RXaDg53eWZ9wbgfff13cDEgO1U0HBymuQZbgNUA9me9+oUz/R/4n6ZPeOWAyfinEj2Jlx32tfUnZyOxkkadZ1Ax9NwcnoeeCgg5kqglyfm4zzTXwdub+AzrSs51ff+1n4WfTzT3wOu9gzH4JzUeuKc/FcARwExdezTc57hs4Bl7usrgBkB838DjPfEWJucPgau98x3BvUnp+NxkkSMZ9xE4N7GYqpjXWk4P8B6usN/ACY08H0oAoYFvu+e9zSOfd+FVM9y/w78jDzTMtxl2wYeJ5551rEvOa0GzvJM+wGwznN87vG+bzg/xI6qY7upOOe7C4HkgGn1Hg+e47Pe81cd2zoP95xEw9+bD4Bf1rOOxpJTBe4Ph3qWHw4Uua+7ADXUnbS74vyYSHeHJwO3NbaPoSrWO09VMzx/zwa5XE+cYr/i2j+cX3Od3OnZOAfSgeiK8wu+1nqcA79T3bPvp7HtbvG8LgWSPMURV8q+4q5i4DD2L5LYu6yqlrov27jxFnrGAWxsJM7AONq4r7t6l3XXWdDIurzz7wIK3fXUFUtP4JaAzy3bnb8rkKfuUenyfg5e2cB6PbCy7f0+XzfmAqCbZ5763p9gNbZ84HvyqOf9KMT55dpNVT8GHgeeALaJyDMikh7EdgKPYdzhbnzffp95Hct9b15VrWlgvUG9d6pagvMLeqw76jKcKy0AROQ3bhHTDvd9aUtAEV098RWp6u6A+GrXGSsiD4pTfL4TJ/EQxHq96w88N3iP9YKAY7LO/XfjuxTnqnWzWww60J1c7/EQTIAi0klEJolInruPr7Bv/xr63hzMOTNfVcs8MaSIyNPiVA/sxLl6yxCnRWU2zvmqKHAlqroJ+Aq40C2KPBPPMVGfcLvPaSOwNiCxpanqWZ7pdZZ1B2ETzgFSq/bX2NYg4+rT1A26ZcrPAjfiFL1kAItwDsrGbAayRCTFMy67qTF41uUtG669RG/I3m2JSBucoqFNnuneZLMR+EPA55aiqhPdbXcT2a9+p0c929wI9JC6mztrHeO89vt8RSQVZx/zGlmuLo1tK5jlNgI/DXhPklX1awBVfUxVR+JcaffHKbJuTOAxDM57Wdc+bmb/46W+97x2vdki4j0f1LfeYEwELhORo3GKmD4B57YBnCKeS3B+YWcAO2j8+7AZpx4uNSC+WpfjFBWdhpPsernja9fbpGPHXfemeuZtkKp+oKqn41xJLMP5/kMjx0MQ/oizH0NUNR2neLR2/xr63jR0zizFKUWp1TlgeuD7dgtOidORbgwnuOPF3U6Wtx4swItuzBcD36hqo8dWuCWnGUCJWwmX7P4iOkxEjnCnPwfcLyL9xDFURGpPsltpOIFMBG4Wp2K1Dc6H/VqQv9LfAbqIyK/EqTxNE5Ejg1guFecDzgcQpyL1sAaXcKnqemAWcK+IJLhf9HMaWaw+k4FzxKlMT8ApPmnshHCWiBznzn8/Tj1afVduzwLXi8iR7ueSKiJni0gaTrFTFfALtyL0Apy6i7rMwDkRPeiuI0lEjnWnbQW6u/HUZSJwlYgMF6cy+4/Ad6q6rpH9rMtWoFfAybqpngLukH0V821F5GL39RHuexWPUwRWhlMk0phpQH8RuVxE4sSpOB+Mc3wGeh3nPe8uIpnA7Q2s9zucE9Vt
7md0Es6xNimYHa0nzp449TSvea7I0nCOhXwgTkTuBtLrXsU+nu/C793vwnHs/11IA8pxrpRTcD57r2DODXeJSAdxGrncjXNl0iTu1c0YN4mW41Rv1O57vcdDkDGmuevbISLd2P/HTEPfm+eA34jISPe7eYjsa9w0D7jcPc+OximGb0gaThFnsYhkAffUTlDVzThFl0+K03AiXkRO8Cw7Bac++Jc4dbONClVyelv2v8/prWAWUueehh/ilGWuxangfg7n1xA4FaKvA//DqXh8HqeiDpwT7ovuZfMldax+AvAyzqXoWpwTwk1BxlWCUxl5Dk7xxkrg5CCWWwL8DecEvRUYgnN5G6wf4ZQnFwAP4FQ2ljdh+do4FuPs6yScg3gXTrl5Q+v6N87BVwiMxPnVU9/6Z+E06ngcpw5hFU4dEapaAVzgDhfiFHu8Wc96qnHe40OADUCuOz84dSiLgS0isr2OZT8E/g/4j7uPfdlXtNRUb7j/C0RkzoGsQFXfwqlMnuQWgSzCKc4A54T8LM57tR7n8/1LEOsswPl+3OIucxvwQ1X93vvhrv8DnIryOdTznrvrrcB538/E+c49CVypqssa3dG611fubu80nOOo1gfA+zj1betxvoONFVXXuhynTroQ57j0nuBecteXByzBadzg9Tww2D03TKlj3Q/gJL8FOA2u5rjjmioG+DXOVVchzsn+Z9Do8QCNn79+j3Ny34FTbLr382zoe6Oqb+DU+/0bp95nCk4pCDiJ4hycerIfudMa8gjO+XY7znv8fsD0K3DqeZfhnF9+5YlxD853szcNHIteta3CTCsgTjPsZap6T6MzN7yeNjgHZD9VXVvH9BdwGincdTDbMcaYWu6Vcn9VrfeHrle4FesZD7f4p68492iNxilXn3KA6zpHnArNVJym5AvZV3FsjDEh4xYDXg08E+wylpzCW2ec5sG7gMeAn+mBd600Bqe4YRPQDxirdtlsjAkxEbkWp/j2PVX9vLH59y5n5ydjjDHhxq6cjDHGhJ2w7T6/qdq3b6+9evXyOwxjjGlVZs+evV1VO/gdR6CISU69evVi1qxZfodhjDGtiog01HuIb6xYzxhjTNix5GSMMSbsWHIyxhgTdiw5GWOMCTuWnIwxxoQdS07GGGPCjiUnY4wxYSdi7nMyxpim2FZSxgeLt5IUF8MRvbLo2S6FGoUVW0uYtb6IlPhYTj+0E+lJ8X6HGpUsORljooaqMnX+JibPzuWrVdup8XQt2r5NIuVV1ZSU7Xv+aMJbMZw2qCOnD+7EkG4Z9G6fSmyMUF2j5JeUIwKd0pN82JPIZ8nJGBMVcotKuf0/C/ly1XZ6ZKVw48mHcM6wrigwc10hs9cVkRgfyxG9MsnpmUXB7nL+O28T7yzYxLSFWwBISYglMyWBrTvLqHIz27GHtOPSI3pwxuBOJMXH+riHkSVieiXPyclR677IGFOXSTM2cP87SwC48+xBXHZED2JiJKhlq6prWJW/i0V5O1mUt4Pi0gq6ZiTTNSOZgl0VvDF7I7lFe2jfJoHfnDGAS3Kyg153OBCR2aqa43ccgSw5GWMi2ivfrueuKYs4pm87/nzhULKzUpp1/TU1ylert/PYRyuZua6Iod3bcu+5h3J4j8xm3U6ohGtystZ6xpiI9fmKfO6ZupiTB3TgpZ+MavbEBBATIxzfrwOv//RoHh07nK07y7jon1/z+Yr8Zt9WNLHkZIyJSCu2lvDzV+fQr2Mb/nH54cTFhvZ0JyKMGd6ND399Iv07pXHTxLlsKCgN6TYjmSUnY0zEWZS3g6v+NZOkhFieH38EbRJbru1XWlI8T18xEoDrXp5FaUVVI0uYulhyMsZEjIqqGv7+v+WMeeIrKqtreH5cDt0ykls8jp7tUnnsshEs31rCbZMXECl1+y3JkpMxJiLM3VDEuY9/yWMfr2LM8K5Mv/lEhnbP8C2eE/t34DdnDOCdBZv5ZPk23+JorSw5GWNatZKySu7+7yIu
+OfXFJVW8Py4HP5+yXDapvjfs8N1J/Sha9sknv5sjd+htDp2E64xptVatmUn4ybMYFtJOeOO7sUtZ/QnLYy6G4qPjeEnx/XmgXeXMn9jMcOyM/wOqdUI6ZWTiIwWkeUiskpEbq9jeg8R+URE5orIAhE5yzPtDne55SLyg1DGaYxpfQp2lXP1C869jVNuOJZ7zz00rBJTrUuPyCYtMY5nvrCrp6YIWXISkVjgCeBMYDBwmYgMDpjtLuB1VR0BjAWedJcd7A4fCowGnnTXZ4wxVFTV8LNX5rB9VznPXJET1lckaUnxXH5UD95buNmaljdBKK+cRgGrVHWNqlYAk4AxAfMokO6+bgtscl+PASaparmqrgVWueszxkQ5VeWeqYuYsa6Qhy4aGtaJqdZVx/QmNkaY8NVav0NpNUKZnLoBGz3Due44r3uBH4tILjANuKkJyyIi14nILBGZlZ9vd2MbE+l27Knk1skLmDhjIzec1Jcxw793WghLndsmce6wbrw2cyNFuyv8DqdV8Lu13mXAC6raHTgLeFlEgo5JVZ9R1RxVzenQoUPIgjTG+O+TZdv4wcOf89bcPG48+RB+c8YAv0NqkmtP6M2eymomztzgdyitQiiTUx6Q7Rnu7o7zuhp4HUBVvwGSgPZBLmuMiRJPfLKKq16YSdvkeN664Rh+84MBrarnb4CBndM5pm87Xv5mPZXVNX6HE/ZCmZxmAv1EpLeIJOA0cJgaMM8G4FQAERmEk5zy3fnGikiiiPQG+gEzQhirMSZM/fu7Dfzlg+WcN7wrU2861tcbaw/WVcf2ZvOOMj5YvMXvUMJeyJKTqlYBNwIfAEtxWuUtFpH7RORcd7ZbgGtFZD4wERivjsU4V1RLgPeBn6tqdahiNcaEp/cXbeauKQs5eUAH/nLxMBLjWnej3VMGdqRHVgr/+mqd36GEPXuekzEmLM1cV8iPnv2Ow7ql88o1R5KSEBl9Bjz/5Vruf2cJU28Mj6tAe56TMcYEqbi0gpv+PZdumclMGH9ExCQmgItzupOaEMsLdvXUIEtOxpiwoqr87q1FbN9VzmNjR5CRkuB3SM0qPSmei0Z25+0Fm9hWUuZ3OGHLkpMxJqz8Z04e7y7czK/P6M+Q7m39DickrjymF5XVytvzN/sdStiy5GSMCRvrC3Zzz38XcWTvLH56Ql+/wwmZvh3a0L9TG6YvsVZ79bHkZIwJC6rKbZMXEBMj/P3S4cS2svuYmur0wZ2Yua6I4lLrMaIulpyMMWFh8uxcvltbyJ1nDfLl6bUt7fTBnamuUXsQYT0sORljfFe4u4I/TltKTs9MLs3JbnyBCDC0W1s6piUyfclWv0MJS5acjDG+++O0pZSUVfGH84e0um6JDlRMjHDa4E58tjyf8irrYyCQJSdjjK++XVPA5Nm5XHtCHwZ0TvM7nBZ1+qBO7K6o5uvVBX6HEnYsORljfLN9Vzm/fm0ePbJS+MUp/fwOp8Ud3bcdKQmxfGhFe99jyckY44vK6hp+/uocCnZX8OSPDic5oXX3m3cgkuJjObF/Bz5cupWamsjoSq65WHIyxvjiT9OW8d3aQv50wRAO6xaZN9sG4/TBndi6s5yFeTv8DiWsWHIyxrS4KXPzmPDVWq46thcXHN7d73B8dfKAjgB8uWq7z5GEF0tOxpgWtal4D3dNWcSoXlncedYgv8PxXWZqAn3apzJvY7HfoYQVS07GmBajqtz51kKqa5S/XjyM+Fg7BQEMz85g3sZiIuURRs3BjgxjTIt5a24eny7P57bRA+jRLsXvcMLGsOwM8kvK2bzDeimvZcnJGNMitpWU8fu3lzCyZybjju7ldzhhZXh2BoAV7XlYcjLGhJyqctdbi9hTWc2fLxwaNb1ABGtglzQSYmOYb8lpL0tOxpiQe+mb9fxvyVZuPWMAh3Rs43c4YScxLpbBXdOZa8lpL0tOxpiQWpBbzAPvLuHUgR25+rjefocTtoZnZ7AwdwdV1TV+hxIWLDkZ
Y0Jmx55Kfv7vOXRMS+Jvlwyz4rwGDM/OYE9lNavyd/kdSlgIaXISkdEislxEVonI7XVMf1hE5rl/K0Sk2DOt2jNtaijjNMY0P1XljjcXsLm4jMcuG0FGSoLfIYW1YbWNIjYU+xpHuIgL1YpFJBZ4AjgdyAVmishUVV1SO4+q3uyZ/yZghGcVe1R1eKjiM8aE1nuLtjBt4RZuGz2AkT0z/Q4n7PVql0Lb5Hjm5xYzdlQPv8PxXSivnEYBq1R1japWAJOAMQ3MfxkwMYTxGGNayI7SSu6ZuphDu6Zz3fF9/A6nVRARhmVnMNeunIDQJqduwEbPcK477ntEpCfQG/jYMzpJRGaJyLcicl7IojTGNLsH319K4e4K/nzhUOKsF4igDc/OYMXWEkorqvwOxXfhctSMBSarqvdxkD1VNQe4HHhERPoGLiQi17kJbFZ+fn5LxWqMacC3awqYOGMjVx/XO6p7Gz8Qw7PbUqOwMNd6KA9lcsoDsj3D3d1xdRlLQJGequa5/9cAn7J/fVTtPM+oao6q5nTo0KE5YjbGHISSskpu/88CemSlcPNp/f0Op9UZ1j0DgPm5xb7GEQ5CmZxmAv1EpLeIJOAkoO+1uhORgUAm8I1nXKaIJLqv2wPHAksClzXGhI+aGuXXr89nY9Ee/nLR0Kh8eODBatcmkS5tk1i2ucTvUHwXstZ6qlolIjcCHwCxwARVXSwi9wGzVLU2UY0FJun+3fEOAp4WkRqcBPqgt5WfMSb8/OPjVUxfspV7zhnMkX3a+R1OqzWgcxpLt1hyCllyAlDVacC0gHF3BwzfW8dyXwNDQhmbMab5fLR0Kw9/uIILDu/G+GN6+R1OqzawczpfrVpDZXVNVD9SJHr33BjTLJZs2smvJs3jsG7p/PH8IYhYLxAHY1CXNCqrlTX5u/0OxVeWnIwxB2xDQSlXTphBm6Q4nrkih6R4q2c6WAM6pwGwbMtOnyPxlyUnY8wByS8p54oJ31FVU8PLV4+ia0ay3yFFhD7t2xAfKyyL8nonS07GmCYrq6xm/L9msG1nORPGH8EhHdP8DiliJMTF0LdDG5ZttisnY4xpkoenr2Dxpp08fvkIDu9h/eY1t4Gd0+zKye8AjDGty/yNxTz7xRouG5XNqYM6+R1ORBrYJZ3NO8rYUVrpdyi+seRkjAlaeVU1t06eT8e0JO44a5Df4USsgdYowpKTMSZ4T3yymhVbd/HHCw4jPSne73Ai1qAu6QBRXbRnyckYE5TZ6wt58pNVXDCiG6cMtOK8UOqYlkhGSrxdORljTEM2FJRy7Uuz6Z6ZzN3nDPY7nIgnIlHfKMKSkzGmQTv2VHLVCzOorlEmjD/CHrfeQgZ2Tmf5lhJqarTxmSOQJSdjTL0qq2u44dXZbCgs5ekrRtKnQxu/Q4oaAzunUVpRzcaiUr9D8YUlJ2NMnWpqlNsmL+CrVQX88fwhHGU9jbeogW6jiKVR+vgMS07GmO9RVf4wbSlvzc3jN2f05+Kc7MYXMs2qf6c2iERvc3JLTsaY73nqszU8/+Vaxh/Ti5+ffIjf4USllIQ4urZNZu326Oyd3JKTMWY/L3+zjj+/v4xzh3Xl7h8Otkdg+KhnuxTWF1idkzEmyj3/5Vr+77+LOW1QR/568TBiYiwx+alnuxQ2FFpyMsZEsac/W8397yxh9KGdefJHI0mIs9OD33pkpVK4u4KSsujrY8+OPmMMz32xhj+9t4yzh3bhH5ePsMQUJnq2SwGIyqI9OwKNiXJvzsnlgXeXcuZhnXn00uHEx9ppIVz0yHKSUzQW7dlRaEwU+2T5Nm6bvIBj+rbjkbHDibPEFFbsyilERGS0iCwXkVUicnsd0x8WkXnu3woRKfZMGyciK92/caGM05hotCC3mBtemcPALmk8fcVIEuNi/Q7JBEhLiicrNYENhdHXnDwuVCsWkVjgCeB0IBeYKSJTVXVJ7TyqerNn/puAEe7rLOAeIAdQYLa7bFGo4jUmmpRVVvOr
1+aRlZrAv8aPIs0efxG2emRFZ3PyUF45jQJWqeoaVa0AJgFjGpj/MmCi+/oHwHRVLXQT0nRgdAhjNSaqPPrRStbk7+ZPFwyhQ1qi3+GYBkTrvU6hTE7dgI2e4Vx33PeISE+gN/BxU5c1xjTNwtwdPPP5Gi7J6c4J/Tv4HY5pRM+sFDbv2ENFVY3fobSocKn9HAtMVtXqpiwkIteJyCwRmZWfnx+i0IyJHBVVNdw6eT7tUhP43dn2XKbWoEe7VGoU8or3+B1KiwplcsoDvL1FdnfH1WUs+4r0gl5WVZ9R1RxVzenQwX4BGtOYxz5aybItJfzh/CG0TbZ6ptZgX4u96GoUEcrkNBPoJyK9RSQBJwFNDZxJRAYCmcA3ntEfAGeISKaIZAJnuOOMMQfok2XbePyTVVw8sjunD7bHrLcWPaP0XqeQtdZT1SoRuREnqcQCE1R1sYjcB8xS1dpENRaYpKrqWbZQRO7HSXAA96lqYahiNSbSbSws5VevzWNQl3TuP+8wv8MxTdAhLZHk+NioaxTRaHISkXOAd1W1ybVxqjoNmBYw7u6A4XvrWXYCMKGp2zTG7K+sspqfvTqbGlWe+vHhJMXb/UytiYhEZXPyYIr1LgVWishDbhGcMaaVUFXufGshi/J28vdLhtOzXarfIZkD0KNdStTdiNtoclLVH+PcHLsaeEFEvnFbyaWFPDpjzEF5+MOVvDknj1+d1s/qmVqxHlnOozM8tR8RL6gGEaq6E5iMcyNtF+B8YI7bq4MxJgxNmrGBxz5aySU53fnlqf38DscchJ7tUiirrGFbSbnfobSYRpOTiJwrIm8BnwLxwChVPRMYBtwS2vCMMQfik+Xb+N2URZzQvwN/OH+IPc22lavtnTya6p2Caa13IfCwqn7uHamqpSJydWjCMsYcqKWbd3Ljq3MY0CmNJ390uD0CIwLU1hWuL9jNqN5ZPkfTMoJJTvcCm2sHRCQZ6KSq61T1o1AFZoxpum07y7j6hZm0SYpjwvgjaJMYsrtFTAvqlpFMjETXvU7B/KR6A/A2I692xxljwsieimqueWkWRaWVPD/uCDq3TfI7JNNMEuJi6NI2mdyi6OnCKJjkFOf2Kg6A+zohdCEZY5pKVbnljXkszNvBY5eN4LBubf0OyTSz9m0SKNhd0fiMESKY5JQvIufWDojIGGB76EIyxjTVk5+uZtrCLdw+eqA1GY9QWakJFEVRcgqmQPp64FUReRwQnEdZXBnSqIwxQft0+Tb++r/lnDusK9ed0MfvcEyIZKYmsGLrLr/DaDGNJidVXQ0cJSJt3OHoeXeMCXPrC3bzi4lzGdg5nT9fONSajEewrJQECu3KaX8icjZwKJBUe/Cr6n0hjMsY04jd5VVc99JsYmKEZ64YSXKC9ZkXybLaJLCnspo9FdVR8VkHcxPuUzj9692EU6x3MdAzxHEZYxqgqtw2eQErt5Xwj8tGkO3epGkiV1aK0w6tqDQ6rp6CaRBxjKpeCRSp6u+Bo4H+oQ3LGNOQpz9fw7sLN/Pb0QM5vp89aDMaZKY6ySlaivaCSU5l7v9SEekKVOL0r2eM8cHnK/J56P1lnD20izWAiCJZUZacgqlzeltEMoC/AHMABZ4NZVDGmLrNXFfIz1+dQ/9OafzlImsAEU1qk1O0FOs1mJxEJAb4SFWLgf+IyDtAkqruaIngjDH7fLYin5++PIuubZOZMP4IUhKsa6JoUlvnVLArOpJTg8V67tNvn/AMl1tiMqblvbdwM9e8OJM+7dvw+vVH0zUj2e+QTAtrmxxPjETPlVMwdU4ficiFYuUHxvhiYe4Obpo4l6HdM5h43VG0b5Pod0jGBzExQmYU3esUTHL6KU5Hr+UislNESkRkZ4jjMsYApRVV/HLSXDqkJfL8uBzaJsf7HZLxUWZq9CSnYHqIsMexG+OT+99ZytqC3bx6zZFkpFh/y9EumnqJaDQ5icgJdY0PfPigMaZ5/W/xFibO2MBPT+zDMX3b+x2O
CQNZqQms2R4dPcgF09znVs/rJGAUMBs4pbEFRWQ08CgQCzynqg/WMc8lOA80VGC+ql7ujq8GFrqzbVDVcwOXNSZSbSsp47f/WcBh3dK55fQBfodjwkRmagKF6+3KCQBVPcc7LCLZwCONLScisTgt/U4HcoGZIjJVVZd45ukH3AEcq6pFItLRs4o9qjo8mJ0wJpKoKrf/ZyGlFdU8culwEuLsMevGkZUaT1FpJTU1SkxMZLdRO5CjPhcYFMR8o4BVqrrGfUDhJGBMwDzXAk+oahGAqm47gHiMiSiTZm7k42XbuP3MgRzS0ap8zT5ZqYlU1yglZVV+hxJywdQ5/QOnyA2cZDYcp6eIxnTDefZTrVzgyIB5+rvb+Aqn6O9eVX3fnZYkIrOAKuBBVZ1SR2zXAdcB9OjRI4iQjAlv6wt2c/87Szj2kHaMO7qX3+GYMJOV6rTWLCytoG1KZLfcDKbOaZbndRUwUVW/asbt9wNOAroDn4vIELdHip6qmicifYCPRWSh+2ypvVT1GeAZgJycHMWYVqy6Rrnl9fnExgh/uWhYxBfbmKbLTKntX6+c3u1TfY4mtIJJTpOBMlWtBqcuSURSVLW0keXygGzPcHd3nFcu8J2qVgJrRWQFTrKaqap5AKq6RkQ+BUYAqzEmQj3z+RpmrS/i4UuHWQ8Qpk7tUp0bsAt3V/ocSegF1UME4P2mJAMfBrHcTKCfiPQWkQRgLDA1YJ4pOFdNiEh7nGK+NSKSKSKJnvHHAkswJkIt2bSTv09fzllDOnPe8G5+h2PCVKZbrFcUBfc6BXPllOR9NLuq7hKRRp9spqpVInIj8AFOfdIEVV0sIvcBs1R1qjvtDBFZAlQDt6pqgYgcAzwtIjU4CfRBbys/YyJJeVU1v359HhkpCTxw3hDradzUq7Zn8gJLTgDsFpHDVXUOgIiMBPYEs3JVnQZMCxh3t+e1Ar92/7zzfA0MCWYbxrR2f//fCpZtKWHC+Jy9Jx9j6pKSEEdSfExUdP4aTHL6FfCGiGzCeUx7Z5zHthtjDtLrMzfyzBdruGxUD04Z2MnvcEwrEC1dGAVzE+5MERkI1N6mvtxtwGCMOUA1Ncpf/7ecJz9dzfH92nPX2cHcOmhM9HT+2miDCBH5OZCqqotUdRHQRkRuCH1oxkSm8qpqfjFpLk9+uprLRmUzYfwRpCbagwNNcLIsOe11rXvfEQBubw7XhiwiYyLcn6Yt450Fm7n9zIH88fwhxMda90QmeFmpCVFR5xTMtyLW+6BBt888q7U15gB8uGQrL3y9jquO7cX1J/a1lnmmyTJTEiiMgke1B1OW8D7wmog87Q7/FHgvdCEZE5m27Cjj1snzGdwlndvPHOh3OKaVykpNoKS8ioqqmojuFDiY5PRbnP7rrneHF+C02DPGBKm6RvnVa3Mpr6rhH5ePIDEu1u+QTCtVe7tBcWkFHdOTfI4mdBpNu6paA3wHrMPpafwUYGlowzImsjz64Qq+XVPI7889lL4d2vgdjmnFouVG3HqvnESkP3CZ+7cdeA1AVU9umdCMiQzvL9rCYx+v4uKR3bloZHe/wzGtXG3nr5HehVFDxXrLgC+AH6rqKgARublFojImQqzaVsItr89jWPe23H/eYdYAwhy0dm3cnskjvMVeQ8V6FwCbgU9E5FkRORWnhwhjTBB2llVy3UuzSU6I5akrRpIUb/VM5uBFy5VTvclJVaeo6lhgIPAJTjdGHUXknyJyRgvFZ0yrde9/F7O+sJQnLj+cLm3tERimeWS4DxmM9DqnYBpE7FbVf6vqOTjPZJqL04LPGFOPT5Zt4825edxwUl+O7NPO73BMBImPjaFtcnz0XjnVRVWLVPUZVT01VAEZ09qVlFVy51sL6dexDTeecojf4ZgIlJWaQGFpZHdxGrl3cBnjkz+9t4ytO8t46KKhdj+TCYmMFLtyMsY0wdert/Pv7zZwzfF9GNEj0+9wTIRKT4qnpMyunIwxQSgpq+TWNxbQu30q
N5/W3+9wTARLT45nZ1mV32GElPXTb0wz+f3bS9i8Yw//+dkxJCdYcZ4JnbbJcezcY1dOxphGfLB4C5Nn5/Lzkw+x4jwTculJ8ewsq0RV/Q4lZCw5GXOQtu8q5843F3JYt3R+cWo/v8MxUSA9OZ7KaqWsssbvUELGivWMOQhV1TXc/No8SsqrmHTJcHtwoGkR6UnOjbg7yyojtgg5pN8kERktIstFZJWI3F7PPJeIyBIRWSwi//aMHyciK92/caGM05gD9cC7S/li5XYeOO8w+nVK8zscEyXSk53rih0RXO8Usisn94m5TwCnA7nATBGZqqpLPPP0A+4AjlXVIhHp6I7PAu4BcgAFZrvLFoUqXmOaatKMDbzw9TquPq43l+Rk+x2OiSJ7r5wiODmF8sppFLBKVdeoagUwCRgTMM+1wBO1SUdVt7njfwBMV9VCd9p0YHQIYzWmSWasLeT//ruIE/p34A57qq1pYenJ+4r1IlUok1M3YKNnONcd59Uf6C8iX4nItyIyugnLIiLXicgsEZmVn5/fjKEbU78VW0u49qVZZGem8I/LRhBn9UymhaUnOYVeO/dE7r1Ofn+r4oB+wEk4DzV8VkQygl3Y7ecvR1VzOnToEJoIjfHILSrlyudnkBgXw4s/GUVb9xesMS3JrpwOTh7gLYjv7o7zygWmqmqlqq4FVuAkq2CWNaZFFewq58rnZ7C7oooXfzKK7KwUv0MyUSpt75WTJacDMRPoJyK9RSQBGAtMDZhnCs5VEyLSHqeYbw3wAXCGiGSKSCZwhjvOGF/sLq/iJy/MJK94D8+PO4JBXdL9DslEscS4WJLiYyK6C6OQtdZT1SoRuREnqcQCE1R1sYjcB8xS1ansS0JLgGrgVlUtABCR+3ESHMB9qloYqliNaUhFVQ3XvzKbRZt28tSPRzKqd5bfIRnj9BIRwVdOIb0JV1WnAdMCxt3tea3Ar92/wGUnABNCGZ8xjampUW55Yz5frNzOQxcN5fTBnfwOyRigtvPXyE1OfjeIMCasPfj+Mt6ev4nfjh5o9zKZsJKeFGet9YyJRvM3FvPsF2u4/MgeXH9iH7/DMWY/duVkTBSqqVHunrqYdqmJ3HHmQETE75CM2U/b5Miuc7LkZEwdJs/OZf7GYu48ayBpSXYvkwk/zmMzrFjPmKixo7SSP7+/jJyemZw/4nsdkxgTFtLdBw5G6jOd7JEZxgT42/TlFJVW8NKYUVacZ8JWelI8VTXKnspqUhIi71RuV07GeLz49Tpe+mY9447pxaFd2/odjjH12tuFUYS22LPkZIxr8uxc7pm6mDMGd+J3Zw3yOxxjGlT72IxIfaaTJSdjgPcWbua2yfM5vl97/nG59TRuwl/tAwcjtTm5fQNN1Ju3sZhfTprHiB6ZPH3FSBLjIvOx1yayRPoDBy05mai2raSM61+eTae2iTx3ZU5EViybyBTpj82wb6KJWhVVNfz81TkU76ngzZ8dS2Zqgt8hGRO0SH/goCUnE7X+8O4SZq4r4tGxwxnc1R6BYVqXNCvWMybyvDU3lxe/Wc81x/VmzHC70da0PglxMSTHx0ZssZ4lJxN1lm3ZyR1vLmRU7yxuP3Og3+EYc8CcXiIis1jPkpOJKjvLKvnZK3NIS4rncWsyblo5p3+9yLxysjonEzVUlVvfmM+GwlImXnsUHdOS/A7JmIMSyY/NsJ+NJmo8/OFKPli8lTvOHGiPWjcRIZIfOGjJyUSFt+bm8thHK7kkpztXH9fb73CMaRZt7crJmNZrxtpCfjt5IUf3accD5w2xnsZNxEiP4AcOWnIyES2veA8/fXkW3bOSeerHI0mIs0PeRI7aBw5G4jOdQvpNFZHRIrJcRFaJyO11TB8vIvkiMs/9u8Yzrdozfmoo4zSRqbpGufm1eVRWK8+PO4K2KfZEWxNZ0pPjqK5RSiuq/Q6l2YWstZ6IxAJPAKcDucBMEZmqqksCZn1NVW+sYxV7VHV4qOIzke+pz1YzY20h
f7t4GL3bp/odjjHNbm/nr2WVpCZGVuPrUF45jQJWqeoaVa0AJgFjQrg9Y/aav7GYh6ev4IdDu3DB4dYDhIlMkfzAwVAmp27ARs9wrjsu0IUiskBEJotItmd8kojMEpFvReS8ujYgIte588zKz89vvshNq1a0u4JfvTaPjmmJ/MEaQJgIFskPHPS7dvhtoJeqDgWmAy96pvVU1RzgcuAREekbuLCqPqOqOaqa06FDh5aJ2IS1L1bmM/rRz8ktKuXvlw63eiYT0fY+cNCSU5PkAd4roe7uuL1UtUBVy93B54CRnml57v81wKfAiBDGalq58qpq7n9nCVc8P4O0pHjeuuFYjurTzu+wjAkpb51TpAllDdpMoJ+I9MZJSmNxroL2EpEuqrrZHTwXWOqOzwRKVbVcRNoDxwIPhTBW04qVVlTx05dn88XK7Yw7uie3nzmI5AR7mq2JfPvqnCw5BU1Vq0TkRuADIBaYoKqLReQ+YJaqTgV+ISLnAlVAITDeXXwQ8LSI1OBc3T1YRys/Y9ixp5KrX5jJnA1FPHTRUC7JyW58IWMiRFrtAwfLIq9BREjbHqrqNGBawLi7Pa/vAO6oY7mvgSGhjM20fgW7yrni+Rms3FbC45cfzllDuvgdkjEtKj42hpSEWLtyMiZc7NhTyRXPz2B1/i6evTKHkwZ09DskY3wRqY/NsORkWp3d5VVc9S/nium5cUdwYn9rqWmiV6Q+cNDvpuTGNElZZTXXvTyLeRuL+cdlIywxmagXqVdOlpxMq1FSVsnVL87kq1UF/OWiYYw+zOqYjMlMTSAC+321Yj3TOmwrKeOqf81k2ZYS/nbxMC4c2d3vkIwJC89emeN3CCFhycmEvbXbd3PlhO/YXlLBc+NyONkaPxgT8Sw5mbC2dPNOrnh+BjWqTLzuKIZnZ/gdkjGmBVhyMmFrzoYixk+YQUpCHK9ccxSHdGzjd0jGmBZiycmEpY+WbuWmiXPpkJbIK1cfSXZWit8hGWNakCUnE1a2lZRx/ztLeXv+JgZ2TuOln4yiY3qS32EZY1qYJScTFsqrqpn43Qb+Pn0FZZU13Hxaf64/qQ+JcdaBqzHRyJKT8VVFVQ1vzN7I4x+vYvOOMo7p2477zzuMvh2sfsmYaGbJyfjm61Xb+d2URazdvpvDe2Tw14uHcUzfdvbkWmOMJSfT8gp3V/DAu0t4c04ePdulMGG8c++SJSVjTC1LTqbFrNq2i5e+Wcfk2blUVNVw48mHcOMph5AUb/VKxpj9WXIyIVVRVcOHS7cyccYGvli5nYTYGM4Z1pXrT+xDv05pfodnjAlTlpxMs1NVFm/ayZS5ebw5N4/C3RV0bZvEr0/vz+VH9qB9m0S/QzTGhDlLTqZZlFVWM29jMd+sLuCdBZtYnb+b+Fjh1IGdGDsqm+P7dSA2xuqUjDHBseRkmkxVWb61hAUbd7Bo0w4W5O5g8aYdVFY7/faP6p3F1cf14czDOpOZmuBztMaY1siSkwlaTY3y4dKt/POz1czdUAxAm8Q4BndN55rj+3BEr0wO75FJRoolJGPMwbHkZBq1p6KaKfPymPDlWlZu20X3zGTuPWcwJw7oSM+sFGKsuM4Y08xCmpxEZDTwKBALPKeqDwZMHw/8BchzRz2uqs+508YBd7njH1DVF0MZq/m+Nfm7eGN2LhNnbKC4tJKBndN45NLh/HBoF+Ji7SHKxpjQCVlyEpFY4AngdCAXmCkiU1V1ScCsr6nqjQHLZgH3ADmAArPdZYtCFa9x6pJW5+/i0+X5TJ2/iQW5O4gR+MGhnRl/TC9G9c6yG2WNMS0ilFdOo4BVqroGQEQmAWOAwORUlx8A01W10F12OjAamBiiWKNSRVUNizbtYNa6QmauK2LWukKKSisBGNKtLXedPYhzhnWlk/UKboxpYaFMTt2AjZ7hXODIOua7UEROAFYAN6vqxnqW7Ra4oIhcB1wH0KNHj2YKu3VTVSqrlZ1llRTtrqBgdwVFuysoKq2k
qLSCbTvLyCveQ15xGWvyd1FeVQNAz3YpnDqoE6N6ZXFknyx6tkv1eU+MMdHM7wYRbwMTVbVcRH4KvAicEuzCqvoM8AxATk6OhibE0KupUcqraiirrN47rqpG2eomks3FeyjcXUFhqZNkYkVIjo8lKT6GotJKNhXvYfOOMnbsqWRPZTXVNfW/FWlJcXTLSKZbRjLH9G1HTs9MRvbKpGOaXR0ZY8JHKJNTHpDtGe7OvoYPAKhqgWfwOeAhz7InBSz7abNHCBTsKuf4hz4Jat64GCElIY7khFjimthCrVqVsopq9lRW771aAah2E1NjRCAzJYGM5HhqVNlTWc2eimoyUhLompHEkb2zyEhJIDkhhuT4WNKT48lMSSArNYGMlHjapSaSkRJv/dgZY1qFUCanmUA/EemNk2zGApd7ZxCRLqq62R08F1jqvv4A+KOIZLrDZwB3hCLIpPhYfnRk40WCqs7VzJ6Kakorq6muaTyheIkIKfGxJCfEkhAbQ227AhEhKT6W5PhYEuNiqM15MTFCx7REumYk06VtMlmpCdbDgjEmaoQsOalqlYjciJNoYoEJqrpYRO4DZqnqVOAXInIuUAUUAuPdZQtF5H6cBAdwX23jiOaWmhjH784eHIpVG2OMOUCi2mqravaTk5Ojs2bN8jsMY4xpVURktqrm+B1HILuT0hhjTNix5GSMMSbsWHIyxhgTdiw5GWOMCTuWnIwxxoQdS07GGGPCjiUnY4wxYSdi7nMSkXxg/UGsoj2wvZnCaS2icZ8hOvc7GvcZonO/m7rPPVW1Q6iCOVARk5wOlojMCscb0UIpGvcZonO/o3GfITr3O1L22Yr1jDHGhB1LTsYYY8KOJad9nvE7AB9E4z5DdO53NO4zROd+R8Q+W52TMcaYsGNXTsYYY8KOJSdjjDFhJ6qSk4iMFpHlIrJKRG6vY3qiiLzmTv9ORHr5EGazC2K/fy0iS0RkgYh8JCI9/YizOTW2z575LhQRFZFW3/QWgttvEbnE/bwXi8i/WzrG5hbE8d1DRD4RkbnuMX6WH3E2JxGZICLbRGRRPdNFRB5z35MFInJ4S8d40FQ1Kv5wnsa7GugDJADzgcEB89wAPOW+Hgu85nfcLbTfJwMp7uuftfb9Dmaf3fnSgM+Bb4Ecv+Nuoc+6HzAXyHSHO/oddwvs8zPAz9zXg4F1fsfdDPt9AnA4sKie6WcB7wECHAV853fMTf2LpiunUcAqVV2jqhXAJGBMwDxjgBfd15OBU0VEWjDGUGh0v1X1E1UtdQe/Bbq3cIzNLZjPGuB+4M9AWUsGF0LB7Pe1wBOqWgSgqttaOMbmFsw+K5Duvm4LbGrB+EJCVT8HChuYZQzwkjq+BTJEpEvLRNc8oik5dQM2eoZz3XF1zqOqVcAOoF2LRBc6wey319U4v7has0b32S3myFbVd1sysBAL5rPuD/QXka9E5FsRGd1i0YVGMPt8L/BjEckFpgE3tUxovmrq9z7sxPkdgAkfIvJjIAc40e9YQklEYoC/A+N9DsUPcThFeyfhXCF/LiJDVLXYz6BC7DLgBVX9m4gcDbwsIoepao3fgZn6RdOVUx6Q7Rnu7o6rcx4RicMpAihokehCJ5j9RkROA34HnKuq5S0UW6g0ts9pwGHApyKyDqdMfmoENIoI5rPOBaaqaqWqrgVW4CSr1iqYfb4aeB1AVb8BknA6R41kQX3vw1k0JaeZQD8R6S0iCTgNHqYGzDMVGOe+vgj4WN3axVas0f0WkRHA0ziJqbXXQUAj+6yqO1S1var2UtVeOPVs56rqLH/CbTbBHONTcK6aEJH2OMV8a1owxuYWzD5vAE4FEJFBOMkpv0WjbHlTgSvdVntHATtUdbPfQTVF1BTrqWqViNwIfIDTwmeCqi4WkfuAWao6FXge55J/FU5l41j/Im4eQe73X4A2wBtu+48Nqnqub0EfpCD3OeIEud8fAGeIyBKgGrhVVVtt6UCQ+3wL8KyI3IzTOGJ8a//R
KSITcX5ktHfr0u4B4gFU9SmcurWzgFVAKXCVP5EeOOu+yBhjTNiJpmI9Y4wxrYQlJ2OMMWHHkpMxxpiwY8nJGGNM2LHkZIwxJuxYcjJRT0Taicg892+LiOS5r4vdJtfNvb17ReQ3TVxmVz3jXxCRi5onMmPChyUnE/VUtUBVh6vqcOAp4GH39XCg0S5u3N5EjDHNyJKTMQ2LFZFn3Wcf/U9EkgFE5FMReUREZgG/FJGRIvKZiMwWkQ9qe4AWkV94npU1ybPewe461ojIL2pHivNsrUXu368Cg3Hv+H/cfX7Rh0DH0O6+Mf6wX3zGNKwfcJmqXisirwMXAq+40xJUNUdE4oHPgDGqmi8ilwJ/AH4C3A70VtVyEcnwrHcgznO00oDlIvJPYCjOnfxH4jyH5zsR+UxV53qWOx8YgPNcok7AEmBCKHbcGD9ZcjKmYWtVdZ77ejbQyzPtNff/AJyOZKe73T/FArX9mC0AXhWRKTj92tV61+1gt1xEtuEkmuOAt1R1N4CIvAkcj/NwwFonABNVtRrYJCIfH/wuGhN+LDkZ0zBvD+3VQLJneLf7X4DFqnp0HcufjZNQzgF+JyJD6lmvfReN8bA6J2MO3nKgg/usIEQkXkQOdZ8bla2qnwC/xXkES5sG1vMFcJ6IpIhIKk4R3hcB83wOXCoisW691snNvTPGhAP7tWbMQVLVCrc592Mi0hbne/UIzrOSXnHHCfCYqha7RX91rWeOiLwAzHBHPRdQ3wTwFnAKTl3TBuCbZt4dY8KC9UpujDEm7FixnjHGmLBjyckYY0zYseRkjDEm7FhyMsYYE3YsORljjAk7lpyMMcaEHUtOxhhjws7/AyaZJkF0GY9+AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(thresholds, accuracies)\n",
    "plt.title(\"Effect of changing prediction threshold on validation set accuracy\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.xlabel(\"Threshold\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c10a50bb-8c66-42bc-9be1-a0214093cdf4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best threshold: 0.8300000000000001\n"
     ]
    }
   ],
   "source": [
    "best_threshold = thresholds[np.argmax(accuracies)]\n",
    "print(\"Best threshold: {}\".format(best_threshold))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a755650e-3379-4ae1-8821-611e9006d63e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test set accuracy with threshold tuning (requiring 5000 labelled samples): 0.8224\n"
     ]
    }
   ],
   "source": [
    "prompted.eval()\n",
    "print(\"Test set accuracy with threshold tuning (requiring {} labelled samples): {}\".format(len(X_val), accuracy_score(y_test, predict(X_test) > best_threshold)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2633d2b7-1480-4a97-830f-6ad1b51bd6da",
   "metadata": {},
   "source": [
    "# Few Shot Learning\n",
    "Now we see the effect of training the prompted model using only a few labelled samples (e.g. 20), not the entire validation set to tune the threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "df65c810-fcc7-4083-872d-02d020f868ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 Train Loss: 0.6975820660591125 Validation Accuracy: 0.79\n",
      "Epoch 1 Train Loss: 0.002826210344210267 Validation Accuracy: 0.8\n",
      "Epoch 2 Train Loss: 0.004467617720365524 Validation Accuracy: 0.82\n",
      "Epoch 3 Train Loss: 0.0010386136127635837 Validation Accuracy: 0.84\n",
      "Epoch 4 Train Loss: 0.0007217149832285941 Validation Accuracy: 0.84\n",
      "Epoch 5 Train Loss: 0.0003348553436808288 Validation Accuracy: 0.84\n",
      "Epoch 6 Train Loss: 0.0005014603957533836 Validation Accuracy: 0.85\n",
      "Epoch 7 Train Loss: 0.0003245403349865228 Validation Accuracy: 0.85\n",
      "Epoch 8 Train Loss: 0.00043502970947884023 Validation Accuracy: 0.86\n",
      "Epoch 9 Train Loss: 0.00023206780315376818 Validation Accuracy: 0.85\n"
     ]
    }
   ],
   "source": [
    "# sample a few training examples\n",
    "n_samples = 20\n",
    "fine_tune_idx = np.random.choice(len(X_train), size=n_samples)\n",
    "fine_tune_x = X_train[fine_tune_idx]\n",
    "fine_tune_y = torch.tensor(y_train[fine_tune_idx]).float().to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(prompted.parameters(), lr=1e-5, betas=(0.95, 0.9995))\n",
    "loss_func = torch.nn.BCELoss()\n",
    "\n",
    "for epoch in range(10):  \n",
    "    prompted.train()\n",
    "    for i in range(len(fine_tune_x)):\n",
    "        prompted.zero_grad()\n",
    "        preds = prompted(fine_tune_x[i:i+1])\n",
    "        loss = loss_func(preds[0][0], fine_tune_y[i])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "    prompted.eval()\n",
    "    accuracy = accuracy_score(y_val[:100], predict(X_val[:100]) > 0.5)\n",
    "    print(\"Epoch {} Train Loss: {} Validation Accuracy: {}\".format(epoch, loss, accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d8808ca8-3c8e-4047-8b4e-3301e22098c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test set accuracy with few shot learning (requiring 20 labelled samples): 0.8998\n"
     ]
    }
   ],
   "source": [
    "prompted.eval()\n",
    "print(\"Test set accuracy with few shot learning (requiring {} labelled samples): {}\".format(n_samples, accuracy_score(y_test, predict(X_test) > 0.5)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6988b1c8-ae29-4ee0-b4c4-1730d14e40ff",
   "metadata": {},
   "source": [
    "# P-tuning\n",
    "Implement a method of continuous prompt tuning from (https://github.com/THUDM/P-tuning) and test its results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e3654cbe-1b33-41d4-b59b-ce157bf949a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertTokenizer, BertModel\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "bert = BertModel.from_pretrained('bert-base-uncased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cad35ba2-8f9d-4b3a-b445-75a6d42def8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class Plan:\n",
    "# take in the length of the prompt\n",
    "# create an embedding layer for that many tokens\n",
    "# create an LSTM and MLP to process the prompt \n",
    "# How to feed the embedding to BERT\n",
    "    # Get the embedding layer using https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#transformers.BertModel.get_input_embeddings \n",
    "    # Then directly pass embedding to BERT using self.bert(input_embeds=, attention_mask=)\n",
    "# add prompting logic to the tokenizer function\n",
    "    # it needs to first pass all tokens through LSTM and MLP\n",
    "    # then add them into the sentence as needed\n",
    "class BERTPrompt(torch.nn.Module):\n",
    "    def __init__(self, bert, tokenizer, max_length=512, prompt_length=5):\n",
    "        super().__init__()\n",
    "        self.hidden_size = 768\n",
    "        self.max_length = max_length\n",
    "        self.prompt_length = prompt_length\n",
    "        \n",
    "        self.bert = bert.cuda()\n",
    "        self.tokenizer = tokenizer\n",
    "        self.bert_embedding = self.bert.get_input_embeddings()\n",
    "        self.linear = torch.nn.Linear(self.hidden_size, 1).cuda()\n",
    "        self.act = torch.nn.Sigmoid() \n",
    "        \n",
    "        # p tuning modules\n",
    "        self.prompt_embedding = torch.nn.Embedding(prompt_length, self.hidden_size).cuda()\n",
    "        self.lstm_head = torch.nn.LSTM(input_size=self.hidden_size, hidden_size=int(self.hidden_size/2),\n",
    "            num_layers=2, bidirectional=True, batch_first=True).cuda() # takes (batch_size, sequence length, hidden_size)\n",
    "        self.mlp_head = torch.nn.Sequential(torch.nn.Linear(self.hidden_size, self.hidden_size),\n",
    "                                            torch.nn.ReLU(),\n",
    "                                            torch.nn.Linear(self.hidden_size, self.hidden_size)).cuda()\n",
    "        \n",
    "    # calls the tokenize function to get input embeddings, then passes them through bert\n",
    "    def forward(self, input_sents):\n",
    "        embeds = torch.zeros(len(input_sents), self.max_length, self.hidden_size).cuda()\n",
    "        att_mask = torch.zeros(len(input_sents), self.max_length).cuda().long()\n",
    "        for i, sent in enumerate(input_sents):\n",
    "            embeds[i, :, :], att_mask[i, :] = self.tokenize(sent)\n",
    "        \n",
    "        _, cls_out = self.bert(inputs_embeds=embeds, attention_mask=att_mask)\n",
    "        logits = self.linear(cls_out)\n",
    "        return self.act(logits)\n",
    "    \n",
    "    # a tokenize function that embeds the sentence adds the prompt to the end and returns token embeddings along with attention mask\n",
    "    def tokenize(self, input_sent):\n",
    "        # generate prompt tokens from embedding\n",
    "        prompt_tokens = self.prompt_embedding(torch.arange(self.prompt_length).cuda())\n",
    "        prompt_tokens = torch.unsqueeze(prompt_tokens, 0) # add a batch dimension\n",
    "        prompt_tokens, _ = self.lstm_head(prompt_tokens)\n",
    "        prompt_tokens = self.mlp_head(prompt_tokens)[0]\n",
    "        \n",
    "        # Encode the input_sentence\n",
    "        encoding_dict = self.tokenizer.encode_plus(input_sent, truncation=True, max_length=self.max_length, padding=\"max_length\")\n",
    "        sep_pos = encoding_dict[\"input_ids\"].index(102) # the location of the [SEP] token is at the end of the input\n",
    "        token_embeds = self.bert_embedding(torch.tensor(encoding_dict[\"input_ids\"]).cuda())\n",
    "        \n",
    "        # Add the prompt tokens to the end and modify the att_mask to include the prompt\n",
    "        start_prompt_pos = min(sep_pos, self.max_length - self.prompt_length - 1) # if the sentence is truncutated we must further truncate it to fit in the prompt\n",
    "        end_prompt_pos = start_prompt_pos + self.prompt_length + 1\n",
    "        token_embeds[start_prompt_pos:end_prompt_pos, :] = torch.cat([prompt_tokens, token_embeds[sep_pos:sep_pos+1, :]], dim=0)\n",
    "        att_mask = encoding_dict[\"attention_mask\"]\n",
    "        att_mask[start_prompt_pos:end_prompt_pos] = [1]*(self.prompt_length+1)\n",
    "        att_mask = torch.tensor(att_mask).cuda()\n",
    "        \n",
    "        return token_embeds, att_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "49d41233-1367-47c5-9045-62a229945d2b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.4832]], device='cuda:0', grad_fn=<SigmoidBackward>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pBERT = BERTPrompt(bert, tokenizer, prompt_length=5)\n",
    "pBERT.forward([\"Hello there\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "307bdaff-fec8-4fc5-bf2d-a30a50a3508e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 Train Loss: 1.2669013738632202 Validation Accuracy: 0.5\n",
      "Epoch 1 Train Loss: 1.0662970542907715 Validation Accuracy: 0.5\n",
      "Epoch 2 Train Loss: 0.9416559934616089 Validation Accuracy: 0.5\n",
      "Epoch 3 Train Loss: 0.640381932258606 Validation Accuracy: 0.5\n",
      "Epoch 4 Train Loss: 0.5026525855064392 Validation Accuracy: 0.51\n",
      "Epoch 5 Train Loss: 0.2008924037218094 Validation Accuracy: 0.49\n",
      "Epoch 6 Train Loss: 0.19231173396110535 Validation Accuracy: 0.49\n",
      "Epoch 7 Train Loss: 0.19236771762371063 Validation Accuracy: 0.49\n",
      "Epoch 8 Train Loss: 0.14558613300323486 Validation Accuracy: 0.53\n",
      "Epoch 9 Train Loss: 0.07228980958461761 Validation Accuracy: 0.49\n"
     ]
    }
   ],
   "source": [
    "# sample a few training examples\n",
    "n_samples = 1000\n",
    "fine_tune_idx = np.random.choice(len(X_train), size=n_samples)\n",
    "fine_tune_x = X_train[fine_tune_idx]\n",
    "fine_tune_y = torch.tensor(y_train[fine_tune_idx]).float().to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(pBERT.parameters(), lr=1e-5, betas=(0.95, 0.9995))\n",
    "loss_func = torch.nn.BCELoss()\n",
    "\n",
    "for epoch in range(10):  \n",
    "    pBERT.train()\n",
    "    for i in range(len(fine_tune_x)):\n",
    "        pBERT.zero_grad()\n",
    "        preds = pBERT(fine_tune_x[i:i+1])\n",
    "        loss = loss_func(preds[0][0], fine_tune_y[i])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "    pBERT.eval()\n",
    "    with torch.no_grad():\n",
    "        accuracy = accuracy_score(y_val[:100], pBERT.forward(X_val[:100]).cpu().numpy() > 0.5)\n",
    "    print(\"Epoch {} Train Loss: {} Validation Accuracy: {}\".format(epoch, loss, accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "812d6ab5-46a3-488c-93ad-8ac0f51a8c65",
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "cuDNN error: CUDNN_STATUS_EXECUTION_FAILED",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-16-3b31d7b9d4df>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mpBERT\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_val\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-9-ed758ff21e77>\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_sents)\u001b[0m\n\u001b[0;32m     35\u001b[0m         \u001b[0matt_mask\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_sents\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmax_length\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     36\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msent\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_sents\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 37\u001b[1;33m             \u001b[0membeds\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0matt_mask\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtokenize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msent\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     38\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     39\u001b[0m         \u001b[0m_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcls_out\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbert\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minputs_embeds\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0membeds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0matt_mask\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-9-ed758ff21e77>\u001b[0m in \u001b[0;36mtokenize\u001b[1;34m(self, input_sent)\u001b[0m\n\u001b[0;32m     46\u001b[0m         \u001b[0mprompt_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprompt_embedding\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprompt_length\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     47\u001b[0m         \u001b[0mprompt_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprompt_tokens\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# add a batch dimension\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 48\u001b[1;33m         \u001b[0mprompt_tokens\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlstm_head\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprompt_tokens\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     49\u001b[0m         \u001b[0mprompt_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmlp_head\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprompt_tokens\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     50\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    726\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 727\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[0;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\fewrel2\\lib\\site-packages\\torch\\nn\\modules\\rnn.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input, hx)\u001b[0m\n\u001b[0;32m    579\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcheck_forward_args\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    580\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mbatch_sizes\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 581\u001b[1;33m             result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,\n\u001b[0m\u001b[0;32m    582\u001b[0m                               self.dropout, self.training, self.bidirectional, self.batch_first)\n\u001b[0;32m    583\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED"
     ]
    }
   ],
   "source": [
    "pBERT.forward(X_val[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5646018-4a4c-43d6-859d-810b63451587",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
